diff --git a/CHANGES.md b/CHANGES.md
index 1741016b..5b2a79f3 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -16,6 +16,7 @@ All notable changes to this project will be documented in this file.
 - Talon: Allow forcing column types in Talon JSON loader (#104, @nirnayroy)
 - Nx: Update comparison and conditional operations to use boolean tensors (#54, @nirnayroy)
 - Kaun: Split CSV loader into `from_csv` and `from_csv_with_labels` to retain labels when requested (#114, @Satarupa22-SD).
+- Saga: Fix Unigram `token_to_id`/`id_to_token` vocabulary lookups (#117, @RidwanAdebosin)
 
 ## [1.0.0~alpha1] - 2025-10-02
 
diff --git a/saga/lib/tokenizers/models.ml b/saga/lib/tokenizers/models.ml
index ae6de5f1..77743b3d 100644
--- a/saga/lib/tokenizers/models.ml
+++ b/saga/lib/tokenizers/models.ml
@@ -26,7 +26,11 @@ type wordpiece_model = {
 type wordlevel_model = { vocab : (string, int) Hashtbl.t; unk_token : string }
 (** WordLevel model configuration *)
 
-type unigram_model = { vocab : (string * float) list }
+type unigram_model = {
+  vocab : (string * float) list;
+  token_map : (string, int) Hashtbl.t;
+  tokens : string array;
+}
 (** Unigram model configuration *)
 
 (** Main model type *)
@@ -160,9 +164,8 @@ let token_to_id model token =
       try Some (Hashtbl.find vocab token) with Not_found -> None)
   | WordLevel { vocab; _ } -> (
       try Some (Hashtbl.find vocab token) with Not_found -> None)
-  | Unigram { vocab } ->
-      List.find_opt (fun (s, _) -> s = token) vocab |> Option.map (fun _ -> 0)
-(* TODO: Proper implementation *)
+  | Unigram { token_map; _ } -> (
+      try Some (Hashtbl.find token_map token) with Not_found -> None)
 
 (** Get token from ID *)
 let id_to_token model id =
@@ -185,7 +188,8 @@ let id_to_token model id =
       Hashtbl.fold
         (fun token tid acc -> if tid = id then Some token else acc)
         vocab None
-  | Unigram _ -> None (* TODO: Proper implementation *)
+  | Unigram { tokens; _ } ->
+      if id >= 0 && id < Array.length tokens then Some tokens.(id) else None
 
 (** Get vocabulary *)
 let get_vocab model =
@@ -196,7 +200,7 @@ let get_vocab model =
       Hashtbl.fold (fun token id acc -> (token, id) :: acc) vocab []
   | WordLevel { vocab; _ } ->
       Hashtbl.fold (fun token id acc -> (token, id) :: acc) vocab []
-  | Unigram { vocab } -> List.mapi (fun i (token, _) -> (token, i)) vocab
+  | Unigram { vocab; _ } -> List.mapi (fun i (token, _) -> (token, i)) vocab
 
 (** Get vocabulary size *)
 let get_vocab_size model =
@@ -204,7 +208,7 @@ let get_vocab_size model =
   | BPE { vocab; _ } -> Hashtbl.length vocab
   | WordPiece { vocab; _ } -> Hashtbl.length vocab
   | WordLevel { vocab; _ } -> Hashtbl.length vocab
-  | Unigram { vocab } -> List.length vocab
+  | Unigram { tokens; _ } -> Array.length tokens
 
 (** Save model *)
 let save model ~folder ?(prefix = "") () =
@@ -280,6 +284,12 @@ let word_level ?(vocab = []) ?(unk_token = "") () =
   List.iter (fun (token, id) -> Hashtbl.add vocab_tbl token id) vocab;
   WordLevel { vocab = vocab_tbl; unk_token }
 
+let build_unigram_lookup vocab =
+  let token_map = Hashtbl.create (List.length vocab) in
+  List.iteri (fun i (token, _score) -> Hashtbl.replace token_map token i) vocab;
+  let tokens = Array.of_list (List.map fst vocab) in
+  (token_map, tokens)
+
 let unigram ?(vocab = []) ?(unk_token = "") ?(byte_fallback = false)
     ?(max_piece_length = 16) ?(n_sub_iterations = 2) ?(shrinking_factor = 0.75)
     () =
@@ -290,7 +300,9 @@ let unigram ?(vocab = []) ?(unk_token = "") ?(byte_fallback = false)
       n_sub_iterations,
      shrinking_factor )
   in
-  Unigram { vocab }
+  let token_map, tokens = build_unigram_lookup vocab in
+
+  Unigram { vocab; token_map; tokens }
 
 let chars () =
   (* Character-level tokenization - create a special marker *)
diff --git a/saga/lib/tokenizers/models.mli b/saga/lib/tokenizers/models.mli
index 66b2cbca..2dd79649 100644
--- a/saga/lib/tokenizers/models.mli
+++ b/saga/lib/tokenizers/models.mli
@@ -23,7 +23,12 @@ type wordpiece_model = {
 }
 
 type wordlevel_model = { vocab : (string, int) Hashtbl.t; unk_token : string }
-type unigram_model = { vocab : (string * float) list }
+
+type unigram_model = {
+  vocab : (string * float) list;
+  token_map : (string, int) Hashtbl.t;
+  tokens : string array;
+}
 
 (** Main model type *)
 type t =
diff --git a/saga/lib/tokenizers/trainers.ml b/saga/lib/tokenizers/trainers.ml
index 126c103e..f76db828 100644
--- a/saga/lib/tokenizers/trainers.ml
+++ b/saga/lib/tokenizers/trainers.ml
@@ -285,7 +285,7 @@ let train_unigram (_config : unigram_config) _lines _existing_model =
       (fun w c acc -> (w, float_of_int c /. max 1.0 total) :: acc)
       freq []
   in
-  let model = Models.Unigram { vocab } in
+  let model = Models.unigram ~vocab () in
   { model; special_tokens = [] }
 
 (** Main training function *)
diff --git a/saga/test/test_tokenization.ml b/saga/test/test_tokenization.ml
index 1317ef49..6f82ede3 100644
--- a/saga/test/test_tokenization.ml
+++ b/saga/test/test_tokenization.ml
@@ -109,6 +109,68 @@ let test_tokenize_regex_no_match () =
   let tokens = Array.to_list (Encoding.get_tokens encoding) in
   check (list string) "regex no match" [] tokens
 
+(* ───── Unigram Model Tests ───── *)
+
+(* Round-trip lookups *)
+let test_unigram_roundtrip () =
+  let tokens = [ "hello"; "world"; "test" ] in
+  let vocab = List.map (fun token -> (token, 0.0)) tokens in
+  let model = Models.unigram ~vocab () in
+  List.iteri
+    (fun expected_id token ->
+      check (option int)
+        (Printf.sprintf "token_to_id '%s'" token)
+        (Some expected_id)
+        (Models.token_to_id model token);
+      check (option string)
+        (Printf.sprintf "id_to_token %d" expected_id)
+        (Some token)
+        (Models.id_to_token model expected_id))
+    tokens
+
+(* token_to_id - out of vocab *)
+let test_unigram_token_to_id_oov () =
+  let model = Models.unigram ~vocab:[ ("hello", 0.0); ("world", 0.0) ] () in
+  check (option int) "token_to_id out-of-vocab" None
+    (Models.token_to_id model "missing")
+
+(* id_to_token - out of bounds *)
+let test_unigram_id_to_token_oob () =
+  let model = Models.unigram ~vocab:[ ("hello", 0.0); ("world", 0.0) ] () in
+  check (option string) "id_to_token negative" None
+    (Models.id_to_token model (-1));
+  check (option string) "id_to_token out of bounds" None
+    (Models.id_to_token model 10)
+
+(* Test empty vocabulary *)
+let test_unigram_empty_vocab () =
+  let model = Models.unigram ~vocab:[] () in
+  check (option int) "empty vocab token_to_id" None
+    (Models.token_to_id model "test");
+  check (option string) "empty vocab id_to_token" None
+    (Models.id_to_token model 0)
+
+(* Test special characters and unicode *)
+let test_unigram_special_tokens () =
+  let model =
+    Models.unigram
+      ~vocab:
+        [
+          ("<unk>", 0.0);
+          ("<s>", 0.0);
+          ("</s>", 0.0);
+          ("▁hello", 0.0);
+          ("世界", 0.0);
+        ]
+      ()
+  in
+  check (option int) "special <unk>" (Some 0) (Models.token_to_id model "<unk>");
+  check (option int) "special <s>" (Some 1) (Models.token_to_id model "<s>");
+  check (option int) "sentencepiece token" (Some 3)
+    (Models.token_to_id model "▁hello");
+  check (option int) "unicode token" (Some 4) (Models.token_to_id model "世界");
+  check (option string) "id to unicode" (Some "世界") (Models.id_to_token model 4)
+
 (* ───── Edge Cases ───── *)
 
 let test_tokenize_long_text () =
@@ -157,6 +219,14 @@ let tokenization_tests =
     test_case "tokenize repeated punctuation" `Quick
       test_tokenize_repeated_punctuation;
     test_case "tokenize mixed whitespace" `Quick test_tokenize_mixed_whitespace;
+    (* Unigram model tests *)
+    test_case "unigram roundtrip" `Quick test_unigram_roundtrip;
+    test_case "unigram token_to_id out-of-vocab" `Quick
+      test_unigram_token_to_id_oov;
+    test_case "unigram id_to_token out-of-bounds" `Quick
+      test_unigram_id_to_token_oob;
+    test_case "unigram empty vocab" `Quick test_unigram_empty_vocab;
+    test_case "unigram special tokens" `Quick test_unigram_special_tokens;
   ]
 
 let () =
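
A minimal sketch of the behavior this patch introduces, for reviewers; this is hypothetical usage, not part of the diff, and assumes only the `Models` functions shown above:

let () =
  (* build_unigram_lookup assigns ids in vocabulary order *)
  let model = Models.unigram ~vocab:[ ("low", -1.2); ("er", -2.3) ] () in
  assert (Models.token_to_id model "low" = Some 0);
  assert (Models.token_to_id model "er" = Some 1);
  assert (Models.token_to_id model "oov" = None);
  assert (Models.id_to_token model 1 = Some "er");
  assert (Models.id_to_token model 2 = None)

Note that `build_unigram_lookup` uses `Hashtbl.replace`, so if a vocabulary lists the same token twice, `token_to_id` resolves to the last occurrence while `id_to_token` still returns the token stored at each index.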