45 changes: 31 additions & 14 deletions README.md
<p align="center">
<img src="assets/origami_logo.jpg" style="width: 100%; height: auto;">
</p>

# ORiGAMi - Object Representation through Generative Autoregressive Modelling

<p align="center">
| <a href=""><b>ORiGAMi Paper</b></a> | <a href=""><b>ORiGAMi Blog Post</b></a> |
| <a href="https://arxiv.org/abs/2412.17348"><b>ORiGAMi Paper on Arxiv</b></a> |
</p>

## Disclaimer

Please note: This tool is not officially supported or endorsed by MongoDB, Inc. The code is released for use "AS IS" without any warranties of any kind, including, but not limited to, its installation, use, or performance. Do not run this tool against critical production systems.

## Overview

ORiGAMi is a transformer-based Machine Learning model that learns directly from semi-structured data such as JSON or Python dictionaries.

Typically, when working with semi-structured data in a Machine Learning context, the data first needs to be flattened into a tabular form. This flattening can be lossy, especially in the presence of arrays and nested objects, and often requires domain expertise to extract meaningful higher-order features from the raw data. This feature extraction step is manual, slow, and expensive, and it doesn't scale well.
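
To make this concrete, here is a minimal sketch (plain Python on toy data, not ORiGAMi code) of how a nested document loses information when flattened into a single row:

```python
# Toy illustration of lossy flattening -- plain Python, not ORiGAMi code.
order = {
    "customer": {"name": "Ada", "country": "DE"},
    "items": [
        {"sku": "A-17", "qty": 2},
        {"sku": "B-03", "qty": 1},
    ],
}

# A common flattening scheme: dotted paths for nested keys,
# hand-picked aggregates for arrays.
flat = {
    "customer.name": order["customer"]["name"],
    "customer.country": order["customer"]["country"],
    "items.count": len(order["items"]),
    "items.total_qty": sum(item["qty"] for item in order["items"]),
}
print(flat)
# {'customer.name': 'Ada', 'customer.country': 'DE',
#  'items.count': 2, 'items.total_qty': 3}
# The per-item pairing of sku and qty is lost; recovering it would
# require exactly the kind of manual feature engineering described above.
```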

ORiGAMi is a transformer model and, like many other deep learning models, operates directly on the raw data and discovers meaningful features itself. Preprocessing is fully automated (apart from a few hyper-parameters that can improve model performance).

### Use Cases

Once an ORiGAMi model is trained on a collection of JSON objects, it can be used in several ways:

1. **Prediction**: ORiGAMi models can predict the value for any key of the dataset. This differs from typical discriminative models such as Logistic Regression or Random Forests, which have to be trained with a particular target key in mind. ORiGAMi is a generative model trained in an order-agnostic fashion, and a single trained model can predict any target, given any subset of key/value pairs as input.
2. **Autocompletion**: ORiGAMi can auto-complete partial objects based on the probabilities it has learned from the training data, by iteratively sampling next tokens. This also allows it to predict complex multi-token values such as nested objects or arrays (see the sketch after this list).
3. **Generation**: ORiGAMi can generate synthetic mock data by sampling from the distribution it has learned from the training data.
<!-- 4. **Embeddings**: As a deep neural network, ORiGAMi creates contextualized embeddings which can be extracted at the last hidden layer. These embeddings represent the objects in latent space and can be used as inputs to other ML algorithms, for data visualization or similarity search. -->
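
As a toy illustration of the first two use cases (plain Python dicts for exposition; the completed values shown are hypothetical, not actual model output):

```python
# Hypothetical data for illustration -- not output of a real ORiGAMi model.
# Suppose the model was trained on order documents like this one:
train_doc = {
    "customer": {"age": 34, "country": "DE"},
    "items": [{"sku": "A-17", "qty": 2}],
    "status": "shipped",
}

# Prediction: any subset of key/value pairs can serve as input, and any
# key can serve as target -- e.g. predict "status" from the customer alone.
partial = {"customer": {"age": 34, "country": "DE"}}

# Autocompletion: the model samples tokens one at a time until the object
# is complete, so even multi-token values like "items" can be generated.
completed = {
    "customer": {"age": 34, "country": "DE"},
    "items": [{"sku": "B-03", "qty": 1}],  # sampled, multi-token value
    "status": "pending",                   # sampled value
}
```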

Check out the Jupyter notebooks under [`./notebooks/`](./notebooks/) for examples of each of these use cases.

## Installation

ORiGAMi requires Python version 3.10 or higher. We recommend using a virtual environment, such as
Python's native [`venv`](https://docs.python.org/3/library/venv.html).

To install ORiGAMi with `pip`, use

```shell
pip install origami-ml
```


## Usage

ORiGAMi comes with a command line interface (CLI) and a Python SDK.

### Usage from the Command Line

The CLI allows you to train a model, make predictions, and generate synthetic data from a trained model. After installation, run `origami` from your shell to see an overview of available commands.

Help for specific commands is available with `origami <command> --help`, where `<command>` is one of `train`, `predict`, or `generate`.

#### Model Training

To train a model, use the `origami train` command. ORiGAMi works well with MongoDB. For example, to train a model on the `shop.orders` collection on a locally running MongoDB instance on standard port 27017, use the following command:

```
origami train "mongodb://localhost:27017" --source-db shop --source-coll orders
```

#### Making Predictions

...TBD...

#### Generating Synthetic Data

...TBD...

### Usage with Python

To use ORiGAMi from Python, import the model class:

```python
from origami.model import ORIGAMI
```
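
For a rough sense of the workflow, a session might look like the sketch below. Note that the method names (`fit`, `predict`) and their signatures are assumptions for illustration only; refer to the notebooks under [`./notebooks/`](./notebooks/) for the actual API.

```python
# Hypothetical sketch -- `fit`, `predict` and their signatures are
# assumptions, not confirmed ORiGAMi API; see ./notebooks/ for real usage.
from origami.model import ORIGAMI

docs = [
    {"customer": {"age": 34, "country": "DE"}, "status": "shipped"},
    {"customer": {"age": 51, "country": "US"}, "status": "pending"},
]

model = ORIGAMI()        # assumed: default hyper-parameters
model.fit(docs)          # assumed: train directly on a list of dicts
print(model.predict(     # assumed: predict any key from a partial object
    {"customer": {"age": 29, "country": "DE"}},
    target="status",
))
```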

## Experiment Reproduction

This code is released alongside our paper, which can be found on arXiv: [ORIGAMI: A generative transformer architecture for predictions from semi-structured data](https://arxiv.org/abs/2412.17348). To reproduce the experiments in the paper, see the instructions in the [`./experiments/`](./experiments/) directory.
33 changes: 15 additions & 18 deletions experiments/prediction/README.md
guild run <model>:all dataset=tictactoe cross_val=catalog
``` -->


# Reproducing the results from our paper

We use the open source library [guild.ai](https://guild.ai) for experiment management and result tracking.

### Datasets

We bundled all datasets used in the paper in a convenient [MongoDB dump file](). To reproduce the results, first
you need MongoDB installed on your system (or a remote server). Then, download the dump file, unzip it, and restore it into your MongoDB instance:

```
mongorestore dump/
```

This assumes your `mongod` server is running on `localhost` on default port 27017 and without authentication. If your setup varies, consult the [documentation](https://www.mongodb.com/docs/database-tools/mongorestore/) for `mongorestore` on how to restore the data.

If your database setup (URI, port, authentication) differs, also make sure to update the [`.env.local`](.env.local) file in this directory accordingly.

### Hyper-parameter tuning

To conduct a hyper-parameter search for a model, use the following command:

```
NUMPY_EXPERIMENTAL_DTYPE_API=1 guild run <model>:hyperopt dataset=<dataset> --optimizer random --max-trials <num>
```

This will evaluate `<num>` random combinations for model `<model>` on a 5-fold cross-validation for the dataset `<dataset>`:

- `<model>` is the model name; choose from `origami`, `logreg`, `rf`, `xgboost`, `lightgbm`.
- `<dataset>` is the dataset config filename under [`./datasets`](./datasets/). For example, `json2vec-car` refers to the file `json2vec-car.yml`.

Each parameter combination is executed as a separate guild run. To see the best parameters, you can use

```
guild compare -Fo <model>:hyperopt -F"dataset=<dataset>" -u
```

Alternatively, you can provide a `--label <label>` as part of the run command and filter the comparison like so:

```
guild compare -Fl <label> -u
```

Search for the column `test_acc_mean` and sort in descending order (press `S`). Take note of the run ID (an 8-digit hash) of the best run (first column).

To retrieve the flags of this particular run, use:

```
guild runs info <run-id>
```


### Running a hyper-parameter configuration

To run a particular parameter configuration on a dataset, use the following command:

```
guild run <model>:all dataset=<dataset> <param1>=<value1> <param2>=<value2>
```

- `<model>` is the model name; choose from `origami`, `logreg`, `rf`, `xgboost`, `lightgbm`.
- `<dataset>` is the dataset config filename under [`./datasets`](./datasets/). For example, `json2vec-car` refers to the file `json2vec-car.yml`.
- parameters are provided as `<param>=<value>`. For example, to change the number of layers in the model to 6, use `model.n_layer=6`. All available parameters can be found in the [`./flags.yaml`](./flags.yml) file.

### Best ORiGAMi parameters for each dataset

For convenience, we list the invocations with the best hyper-parameters reported in the paper.

#### automobile dataset

```
guild run origami:all dataset=json2vec-automobile model.n_embd=160 model.n_head=8 model.n_layer=5 pipeline.sequence_order=SHUFFLED pipeline.n_bins=10 pipeline.upscale=400 train.batch_size=10 train.n_batches=10000 train.learning_rate=4e-5 cross_val=5-fold
```

#### bank dataset

#### nursery dataset

```
guild run origami:all dataset=json2vec-nursery model.n_embd=64 model.n_head=4 model.n_layer=4 pipeline.sequence_order=SHUFFLED pipeline.upscale=40 train.batch_size=100 train.n_batches=10000 cross_val=5-fold
```

#### seismic dataset

```
guild run origami:all dataset=json2vec-seismic model.n_embd=16 model.n_head=4 model.n_layer=4 pipeline.sequence_order=SHUFFLED pipeline.upscale=1000 train.batch_size=100 train.n_batches=10000 cross_val=5-fold
```

#### student dataset

```
2 changes: 1 addition & 1 deletion requirements.txt
click==8.1.7
click-option-group==0.5.6
guildai==0.9.0
matplotlib==3.9.2
mdbrtools==0.1.1
numpy==1.26.4
omegaconf==2.3.0
pandas==2.2.3