1 change: 1 addition & 0 deletions CHANGES.md
@@ -16,6 +16,7 @@ All notable changes to this project will be documented in this file.
- Talon: Allow forcing column types in Talon JSON loader (#104, @nirnayroy)
- Nx: Update comparison and conditional operations to use boolean tensors (#54, @nirnayroy)
- Kaun: Split CSV loader into `from_csv` and `from_csv_with_labels` to retain labels when requested (#114, @Satarupa22-SD).
+- Saga: Fix Unigram `token_to_id`/`id_to_token` vocabulary lookups (#117, @RidwanAdebosin)

## [1.0.0~alpha1] - 2025-10-02

Expand Down
28 changes: 20 additions & 8 deletions saga/lib/tokenizers/models.ml
@@ -26,7 +26,11 @@ type wordpiece_model = {
type wordlevel_model = { vocab : (string, int) Hashtbl.t; unk_token : string }
(** WordLevel model configuration *)

-type unigram_model = { vocab : (string * float) list }
+type unigram_model = {
+  vocab : (string * float) list;
+  token_map : (string, int) Hashtbl.t;
+  tokens : string array;
+}
(** Unigram model configuration *)

(** Main model type *)
@@ -160,9 +164,8 @@ let token_to_id model token =
try Some (Hashtbl.find vocab token) with Not_found -> None)
| WordLevel { vocab; _ } -> (
try Some (Hashtbl.find vocab token) with Not_found -> None)
-  | Unigram { vocab } ->
-      List.find_opt (fun (s, _) -> s = token) vocab |> Option.map (fun _ -> 0)
-      (* TODO: Proper implementation *)
+  | Unigram { token_map; _ } -> (
+      try Some (Hashtbl.find token_map token) with Not_found -> None)

(** Get token from ID *)
let id_to_token model id =
@@ -185,7 +188,8 @@ let id_to_token model id =
Hashtbl.fold
(fun token tid acc -> if tid = id then Some token else acc)
vocab None
-  | Unigram _ -> None (* TODO: Proper implementation *)
+  | Unigram { tokens; _ } ->
+      if id >= 0 && id < Array.length tokens then Some tokens.(id) else None

(** Get vocabulary *)
let get_vocab model =
@@ -196,15 +200,15 @@ let get_vocab model =
Hashtbl.fold (fun token id acc -> (token, id) :: acc) vocab []
| WordLevel { vocab; _ } ->
Hashtbl.fold (fun token id acc -> (token, id) :: acc) vocab []
-  | Unigram { vocab } -> List.mapi (fun i (token, _) -> (token, i)) vocab
+  | Unigram { vocab; _ } -> List.mapi (fun i (token, _) -> (token, i)) vocab

(** Get vocabulary size *)
let get_vocab_size model =
match model with
| BPE { vocab; _ } -> Hashtbl.length vocab
| WordPiece { vocab; _ } -> Hashtbl.length vocab
| WordLevel { vocab; _ } -> Hashtbl.length vocab
-  | Unigram { vocab } -> List.length vocab
+  | Unigram { tokens; _ } -> Array.length tokens

(** Save model *)
let save model ~folder ?(prefix = "") () =
@@ -280,6 +284,12 @@ let word_level ?(vocab = []) ?(unk_token = "<unk>") () =
List.iter (fun (token, id) -> Hashtbl.add vocab_tbl token id) vocab;
WordLevel { vocab = vocab_tbl; unk_token }

+let build_unigram_lookup vocab =
+  let token_map = Hashtbl.create (List.length vocab) in
+  List.iteri (fun i (token, _score) -> Hashtbl.replace token_map token i) vocab;
+  let tokens = Array.of_list (List.map fst vocab) in
+  (token_map, tokens)
+
let unigram ?(vocab = []) ?(unk_token = "<unk>") ?(byte_fallback = false)
?(max_piece_length = 16) ?(n_sub_iterations = 2) ?(shrinking_factor = 0.75)
() =
@@ -290,7 +300,9 @@ let unigram ?(vocab = []) ?(unk_token = "<unk>") ?(byte_fallback = false)
n_sub_iterations,
shrinking_factor )
in
-  Unigram { vocab }
+  let token_map, tokens = build_unigram_lookup vocab in
+
+  Unigram { vocab; token_map; tokens }

let chars () =
(* Character-level tokenization - create a special marker *)
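Taken together, the models.ml changes turn the Unigram stubs into real O(1) lookups: `token_to_id` becomes a hash-table probe on `token_map`, and `id_to_token` a bounds-checked index into `tokens`. A minimal sketch of the resulting behavior (not part of the diff; the two-word vocabulary is invented for illustration):

let () =
  let model = Models.unigram ~vocab:[ ("hello", 0.0); ("world", 0.0) ] () in
  (* token_to_id now probes the hash table instead of scanning the list *)
  assert (Models.token_to_id model "world" = Some 1);
  (* id_to_token indexes the tokens array, returning None out of bounds *)
  assert (Models.id_to_token model 1 = Some "world");
  assert (Models.token_to_id model "missing" = None);
  assert (Models.id_to_token model 99 = None)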
7 changes: 6 additions & 1 deletion saga/lib/tokenizers/models.mli
@@ -23,7 +23,12 @@ type wordpiece_model = {
}

type wordlevel_model = { vocab : (string, int) Hashtbl.t; unk_token : string }
-type unigram_model = { vocab : (string * float) list }
+
+type unigram_model = {
+  vocab : (string * float) list;
+  token_map : (string, int) Hashtbl.t;
+  tokens : string array;
+}

(** Main model type *)
type t =
2 changes: 1 addition & 1 deletion saga/lib/tokenizers/trainers.ml
@@ -285,7 +285,7 @@ let train_unigram (_config : unigram_config) _lines _existing_model
(fun w c acc -> (w, float_of_int c /. max 1.0 total) :: acc)
freq []
in
-  let model = Models.Unigram { vocab } in
+  let model = Models.unigram ~vocab () in
{ model; special_tokens = [] }

(** Main training function *)
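Routing the trainer through the `Models.unigram` smart constructor, rather than building the `Unigram` variant directly, keeps the derived `token_map` and `tokens` fields in sync with the trained vocabulary. A sketch of that invariant via the public accessors (vocabulary and scores invented for illustration):

let () =
  let vocab = [ ("▁the", -1.5); ("▁cat", -2.0) ] in
  let model = Models.unigram ~vocab () in
  (* vocab size is now derived from the tokens array *)
  assert (Models.get_vocab_size model = 2);
  (* get_vocab assigns ids in list order *)
  assert (Models.get_vocab model = [ ("▁the", 0); ("▁cat", 1) ])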
70 changes: 70 additions & 0 deletions saga/test/test_tokenization.ml
@@ -109,6 +109,68 @@ let test_tokenize_regex_no_match () =
let tokens = Array.to_list (Encoding.get_tokens encoding) in
check (list string) "regex no match" [] tokens

+(* ───── Unigram Model Tests ───── *)
+
+(* Round-trip lookups *)
+let test_unigram_roundtrip () =
+  let tokens = [ "hello"; "world"; "test" ] in
+  let vocab = List.map (fun token -> (token, 0.0)) tokens in
+  let model = Models.unigram ~vocab () in
+  List.iteri
+    (fun expected_id token ->
+      check (option int)
+        (Printf.sprintf "token_to_id '%s'" token)
+        (Some expected_id)
+        (Models.token_to_id model token);
+      check (option string)
+        (Printf.sprintf "id_to_token %d" expected_id)
+        (Some token)
+        (Models.id_to_token model expected_id))
+    tokens
+
+(* token_to_id - out of vocab *)
+let test_unigram_token_to_id_oov () =
+  let model = Models.unigram ~vocab:[ ("hello", 0.0); ("world", 0.0) ] () in
+  check (option int) "token_to_id out-of-vocab" None
+    (Models.token_to_id model "missing")
+
+(* id_to_token - out of bounds *)
+let test_unigram_id_to_token_oob () =
+  let model = Models.unigram ~vocab:[ ("hello", 0.0); ("world", 0.0) ] () in
+  check (option string) "id_to_token negative" None
+    (Models.id_to_token model (-1));
+  check (option string) "id_to_token out of bounds" None
+    (Models.id_to_token model 10)
+
+(* Test empty vocabulary *)
+let test_unigram_empty_vocab () =
+  let model = Models.unigram ~vocab:[] () in
+  check (option int) "empty vocab token_to_id" None
+    (Models.token_to_id model "test");
+  check (option string) "empty vocab id_to_token" None
+    (Models.id_to_token model 0)
+
+(* Test special characters and unicode *)
+let test_unigram_special_tokens () =
+  let model =
+    Models.unigram
+      ~vocab:
+        [
+          ("<unk>", 0.0);
+          ("<s>", 0.0);
+          ("</s>", 0.0);
+          ("▁hello", 0.0);
+          ("世界", 0.0);
+        ]
+      ()
+  in
+  check (option int) "special <unk>" (Some 0) (Models.token_to_id model "<unk>");
+  check (option int) "special <s>" (Some 1) (Models.token_to_id model "<s>");
+  check (option int) "sentencepiece token" (Some 3)
+    (Models.token_to_id model "▁hello");
+  check (option int) "unicode token" (Some 4) (Models.token_to_id model "世界");
+  check (option string) "id to unicode" (Some "世界") (Models.id_to_token model 4)
+
(* ───── Edge Cases ───── *)

let test_tokenize_long_text () =
@@ -157,6 +219,14 @@ let tokenization_tests =
test_case "tokenize repeated punctuation" `Quick
test_tokenize_repeated_punctuation;
test_case "tokenize mixed whitespace" `Quick test_tokenize_mixed_whitespace;
+    (* Unigram model tests *)
+    test_case "unigram roundtrip" `Quick test_unigram_roundtrip;
+    test_case "unigram token_to_id out-of-vocab" `Quick
+      test_unigram_token_to_id_oov;
+    test_case "unigram id_to_token out-of-bounds" `Quick
+      test_unigram_id_to_token_oob;
+    test_case "unigram empty vocab" `Quick test_unigram_empty_vocab;
+    test_case "unigram special tokens" `Quick test_unigram_special_tokens;
]

let () =