Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,115 changes: 535 additions & 580 deletions notebooks/example_origami_dungeons.ipynb

Large diffs are not rendered by default.

781 changes: 775 additions & 6 deletions notebooks/example_rf_dungeons.ipynb

Large diffs are not rendered by default.

11 changes: 2 additions & 9 deletions origami/preprocessing/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.utils.validation import check_is_fitted

from origami.utils.common import ArrayStart, Symbol, pad_trunc, walk_all_leaf_kvs
from origami.utils.common import ArrayStart, Symbol, pad_trunc, reorder_with_target_last, walk_all_leaf_kvs

from .encoder import StreamEncoder
from .utils import CAT_THRESHOLD, deepcopy_df, tokenize
Expand Down Expand Up @@ -122,19 +122,12 @@ class TargetFieldPipe(BasePipe):
def __init__(self, target_field: str):
self.target_field = target_field

def _move_target(self, doc: dict | OrderedDict) -> OrderedDict:
target = doc.pop(self.target_field, Symbol.UNKNOWN)
if isinstance(doc, dict):
doc = OrderedDict(doc)
doc[self.target_field] = target
return doc, target

def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
if "docs" not in X.columns:
raise ColumnMissingException("TargetFieldPipe requires column 'docs' in the DataFrame.")

X = deepcopy_df(X)
docs, targets = zip(*X["docs"].map(self._move_target))
docs, targets = zip(*X["docs"].map(lambda doc: reorder_with_target_last(doc, self.target_field)))
X["docs"] = docs
X["target"] = targets
return X
Expand Down
58 changes: 57 additions & 1 deletion origami/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, Generator, Iterable, Optional, Tuple, Union
from typing import Any, Callable, Generator, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -261,3 +261,59 @@ def progress_callback(model):
print_guild_scalars(**scalars)

return progress_callback


def parse_path(path: str) -> List[str]:
"""Split dot notation path into components."""
return path.split(".")


def get_value_at_path(d: dict, path: List[str]) -> Tuple[Any, bool]:
"""Retrieve value at specified path in nested dictionary.
Returns tuple of (value, found) where found is False if path doesn't exist."""
current = d
for component in path:
if not isinstance(current, dict) or component not in current:
return None, False
current = current[component]
return current, True


def reorder_with_target_last(d: dict, target_path: str) -> Tuple[OrderedDict, Any]:
"""
Reorder dictionary so target field appears last, maintaining nested structure.
If target field doesn't exist, returns (OrderedDict(d), Symbol.UNKNOWN).
"""
path_components = parse_path(target_path)
target_value, found = get_value_at_path(d, path_components)

if not found:
return OrderedDict(d), Symbol.UNKNOWN

def reorder_level(current_dict: dict, remaining_path: List[str]) -> OrderedDict:
if not remaining_path:
return OrderedDict(current_dict)

current_target = remaining_path[0]
result = OrderedDict()

# Add all non-target fields first
for k, v in current_dict.items():
if k != current_target:
result[k] = v if not isinstance(v, dict) else reorder_level(v, [])

# Add target field last
if current_target in current_dict:
target_dict = current_dict[current_target]
if len(remaining_path) > 1:
# If we have more path components, recurse with remaining path
result[current_target] = reorder_level(target_dict, remaining_path[1:])
else:
# If this is the final path component, add it last
result[current_target] = (
target_dict if not isinstance(target_dict, dict) else reorder_level(target_dict, [])
)

return result

return reorder_level(d, path_components), target_value
18 changes: 7 additions & 11 deletions tests/preprocessing/test_pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_supervised_target_pipe(self):
"docs": [
{"a": 1, "b": 2, "c": 3},
{"b": 1, "a": 3, "c": 2},
{"c": 2, "a": 1},
{"c": 2, "a": 1}, # Missing 'b'
]
}
)
Expand All @@ -153,16 +153,12 @@ def test_supervised_target_pipe(self):

self.assertIn("target", df.columns)

for i, doc in enumerate(df["docs"]):
self.assertIn("b", doc)
self.assertIsInstance(doc, OrderedDict)
self.assertEqual(list(doc.keys())[-1], "b")
self.assertEqual(doc["b"], df["target"][i])

self.assertEqual(df["docs"][2]["b"], Symbol.UNKNOWN)

# test that index is range
self.assertEqual(list(range(len(df))), list(df.index))
for i, (doc, target) in enumerate(zip(df["docs"], df["target"])):
if target == Symbol.UNKNOWN:
self.assertNotIn("b", doc)
else:
self.assertIn("b", doc)
self.assertEqual(doc["b"], target)


class TestDocTokenizerPipe(unittest.TestCase):
Expand Down
135 changes: 134 additions & 1 deletion tests/utils/test_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import unittest
from collections import OrderedDict

from origami.utils.common import flatten_docs, walk_all_leaf_kvs
from origami.utils.common import (
Symbol,
flatten_docs,
get_value_at_path,
parse_path,
reorder_with_target_last,
walk_all_leaf_kvs,
)


class TestWalkAllLeafKVs(unittest.TestCase):
Expand Down Expand Up @@ -107,3 +115,128 @@ def test_flatten_docs(self):
self.assertEqual(
flat, [{"foo": 1, "bar.baz": 2, "bar.buz": 3}, {"foo": 4, "bar.[0]": "test", "bar.[1].a": "b"}]
)


class TestDictionaryUtils(unittest.TestCase):
def test_simple_path(self):
self.assertEqual(parse_path("a"), ["a"])

def test_nested_path(self):
self.assertEqual(parse_path("a.b.c"), ["a", "b", "c"])

def test_simple_retrieval(self):
d = {"a": 1}
value, found = get_value_at_path(d, ["a"])
self.assertEqual(value, 1)
self.assertTrue(found)

def test_nested_retrieval(self):
d = {"a": {"b": {"c": 42}}}
value, found = get_value_at_path(d, ["a", "b", "c"])
self.assertEqual(value, 42)
self.assertTrue(found)

def test_missing_key(self):
d = {"a": 1}
value, found = get_value_at_path(d, ["b"])
self.assertFalse(found)
self.assertIsNone(value)

def test_nested_missing_key(self):
d = {"a": {"b": 1}}
value, found = get_value_at_path(d, ["a", "b", "c"])
self.assertFalse(found)
self.assertIsNone(value)

def test_simple_reorder(self):
input_dict = {"a": 1, "b": 2, "c": 3}
expected = OrderedDict([("a", 1), ("c", 3), ("b", 2)])
result, value = reorder_with_target_last(input_dict, "b")
self.assertEqual(dict(result), dict(expected))
self.assertEqual(value, 2)
self.assertEqual(list(result.keys())[-1], "b")

def test_nested_reorder(self):
input_dict = {"a": 1, "b": {"b1": True, "b2": False}, "c": "test"}
result, value = reorder_with_target_last(input_dict, "b.b1")
self.assertEqual(list(result.keys())[-1], "b")
self.assertEqual(list(result["b"].keys())[-1], "b1")
self.assertTrue(value)

def test_deep_nesting(self):
input_dict = {"l1": {"l2": {"l3": {"target": "value", "other": "other_value"}}}}
result, value = reorder_with_target_last(input_dict, "l1.l2.l3.target")
self.assertEqual(list(result["l1"]["l2"]["l3"].keys())[-1], "target")
self.assertEqual(value, "value")

def test_empty_dict(self):
result, value = reorder_with_target_last({}, "any_key")
self.assertEqual(value, Symbol.UNKNOWN)
self.assertEqual(dict(result), {})

def test_various_value_types(self):
input_dict = {
"str": "string",
"int": 42,
"bool": True,
"list": [1, 2, 3],
"dict": {"nested": "value"},
"none": None,
}
for key in input_dict:
result, value = reorder_with_target_last(input_dict, key)
self.assertEqual(list(result.keys())[-1], key)
self.assertEqual(value, input_dict[key])

def test_preserve_nested_structure(self):
input_dict = {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": 4}
result, _ = reorder_with_target_last(input_dict, "a.b.c")
self.assertIsInstance(result["a"], OrderedDict)
self.assertIsInstance(result["a"]["b"], OrderedDict)

def test_target_already_last(self):
input_dict = OrderedDict([("a", 1), ("b", 2), ("target", 3)])
result, value = reorder_with_target_last(input_dict, "target")
self.assertEqual(list(result.keys()), ["a", "b", "target"])
self.assertEqual(value, 3)

def test_multiple_nested_fields(self):
input_dict = {"a": {"x": 1, "y": 2, "z": {"target": "value", "other1": "val1", "other2": "val2"}}, "b": "test"}
result, value = reorder_with_target_last(input_dict, "a.z.target")
self.assertEqual(list(result["a"]["z"].keys())[-1], "target")
self.assertEqual(value, "value")
self.assertIsInstance(result["a"], OrderedDict)
self.assertIsInstance(result["a"]["z"], OrderedDict)

def test_target_is_list(self):
input_dict = {"a": 1, "foo": [1, 2, 3], "b": 2}
result, value = reorder_with_target_last(input_dict, "foo")
self.assertEqual(list(result.keys())[-1], "foo")
self.assertEqual(value, [1, 2, 3])

def test_target_is_dict(self):
input_dict = {"a": 1, "foo": {"nested": "value", "other": 42}, "b": 2}
result, value = reorder_with_target_last(input_dict, "foo")
self.assertEqual(list(result.keys())[-1], "foo")
self.assertEqual(value, {"nested": "value", "other": 42})

def test_missing_field(self):
input_dict = {"a": 1, "b": {"b1": True, "b2": False}, "c": "test"}
# Test missing top-level field
result, value = reorder_with_target_last(input_dict, "nonexistent")
self.assertEqual(dict(result), input_dict) # Structure preserved
self.assertEqual(value, Symbol.UNKNOWN)

# Test missing nested field
result, value = reorder_with_target_last(input_dict, "b.nonexistent")
self.assertEqual(dict(result), input_dict) # Structure preserved
self.assertEqual(value, Symbol.UNKNOWN)

# Test path through non-dict value
result, value = reorder_with_target_last(input_dict, "a.something")
self.assertEqual(dict(result), input_dict) # Structure preserved
self.assertEqual(value, Symbol.UNKNOWN)


if __name__ == "__main__":
unittest.main()