Added Tokens (#93)
* Split by regexp with capture groups

The other split helpers we have don't work for capture groups.
We had to resort to raw `NSRegularExpression`s.

* Build added tokens split regexp, shortcut before pre-tokenization

* Update PreTokenizers so Metaspace can conditionally act

* Create LlamaPreTrainedTokenizer subclass

We need some custom behaviour that's not in the config :(

* Rename test

* Replace with enum for future extensibility
pcuenca authored Apr 28, 2024
1 parent 5e02089 commit 0a606f5
Showing 3 changed files with 200 additions and 26 deletions.
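
For context, the net effect of this commit is that added tokens such as `<|end|>` are split out before normalization and pre-tokenization, so each one maps to a single id. A minimal usage sketch (the checkpoint name and call pattern are taken from the new test file at the end of this diff, not independently verified):

import Tokenizers

// Load a tokenizer whose tokenizer.json declares added tokens such as <|end|>.
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Phi-3-mini-128k-instruct-4bit")

// The added token is matched by the added-tokens regex built in Tokenizer.swift,
// kept as a whole section, and encoded to a single id instead of being split by the model.
let inputIds = tokenizer("This is the <|end|>. My only friend, the <|end|>")
let decoded = tokenizer.decode(tokens: inputIds)
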
83 changes: 62 additions & 21 deletions Sources/Tokenizers/PreTokenizer.swift
@@ -8,28 +8,33 @@
import Foundation
import Hub

public enum PreTokenizerOption: String {
case firstSection
}

public typealias PreTokenizerOptions = Set<PreTokenizerOption>

public protocol PreTokenizer {
-    func preTokenize(text: String) -> [String]
-    func preTokenize(texts: [String]) -> [String]
-    func callAsFunction(texts: [String]) -> [String]
-    func callAsFunction(text: String) -> [String]
+    func preTokenize(text: String, options: PreTokenizerOptions) -> [String]
+    func preTokenize(texts: [String], options: PreTokenizerOptions) -> [String]
+    func callAsFunction(texts: [String], options: PreTokenizerOptions) -> [String]
+    func callAsFunction(text: String, options: PreTokenizerOptions) -> [String]

init(config: Config)
}

extension PreTokenizer {
-    func preTokenize(texts: [String]) -> [String] {
-        texts.flatMap { preTokenize(text: $0) }
+    func preTokenize(texts: [String], options: PreTokenizerOptions = [.firstSection]) -> [String] {
+        texts.flatMap { preTokenize(text: $0, options: options) }
}

-    func callAsFunction(texts: [String]) -> [String] {
-        return preTokenize(texts: texts)
+    func callAsFunction(texts: [String], options: PreTokenizerOptions = [.firstSection]) -> [String] {
+        return preTokenize(texts: texts, options: options)
}

-    func callAsFunction(text: String) -> [String] {
-        return preTokenize(text: text)
+    func callAsFunction(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
+        return preTokenize(text: text, options: options)
}

}

enum PreTokenizerType: String {
@@ -71,9 +76,9 @@ class PreTokenizerSequence: PreTokenizer {
preTokenizers = configs.compactMap { PreTokenizerFactory.fromConfig(config: $0) }
}

-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
        preTokenizers.reduce([text]) { current, preTokenizer in
-            preTokenizer(texts: current)
+            preTokenizer(texts: current, options: options)
}
}
}
@@ -85,7 +90,7 @@ class WhitespacePreTokenizer: PreTokenizer {
re = #"\S+"#
}

-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
return text.ranges(of: re).map { String(text[$0]) }
}
}
@@ -125,7 +130,7 @@ class MetaspacePreTokenizer: PreTokenizer {

// https://github.com/huggingface/tokenizers/blob/accd0650b802f2180df40ef1def3bce32156688e/tokenizers/src/pre_tokenizers/metaspace.rs#L114
// https://github.com/xenova/transformers.js/blob/b07336d8f7ff57453cc164cc68aead2a79cbd57e/src/tokenizers.js#L2153
-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
let normalized = text.replacingOccurrences(of: " ", with: stringReplacement)

// We add a prefix space if:
@@ -141,7 +146,7 @@ class MetaspacePreTokenizer: PreTokenizer {
if prependScheme == .always {
prepend = stringReplacement
}
-            if prependScheme == .first /* && first_section */ {
+            if prependScheme == .first && options.contains(.firstSection) {
prepend = stringReplacement
}
}
@@ -164,7 +169,7 @@ class ByteLevelPreTokenizer: PreTokenizer {
useRegex = config.useRegex?.boolValue ?? true
}

-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
// Split on whitespace and punctuation
let tokens = useRegex ? text.ranges(of: RE).map({ String(text[$0]) }) : [text]
return tokens.map { token in
@@ -186,7 +191,7 @@ class PunctuationPreTokenizer: PreTokenizer {
re = "[^\(PUNCTUATION_REGEX)]+|[\(PUNCTUATION_REGEX)]+"
}

-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
// Ref: https://github.com/xenova/transformers.js/blob/27920d84831e323275b38f0b5186644b7936e1a2/src/tokenizers.js#L1138
return text.ranges(of: re).map { String(text[$0]) }
}
@@ -200,7 +205,7 @@ class DigitsPreTokenizer: PreTokenizer {
re = "[^\\d]+|\\d\(individualDigits ? "" : "+")"
}

-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
return text.ranges(of: re).map { String(text[$0]) }
}
}
@@ -214,7 +219,7 @@ class SplitPreTokenizer: PreTokenizer {
invert = config.invert?.boolValue ?? false
}

-    func preTokenize(text: String) -> [String] {
+    func preTokenize(text: String, options: PreTokenizerOptions = [.firstSection]) -> [String] {
guard let pattern = pattern else { return [text] }
return pattern.split(text, invert: invert)
}
@@ -248,7 +253,7 @@ extension StringSplitPattern {
}
}

-extension String {
+public extension String {
func ranges(of string: String, options: CompareOptions = .regularExpression) -> [Range<Index>] {
var result: [Range<Index>] = []
var start = startIndex
Expand Down Expand Up @@ -277,6 +282,42 @@ extension String {
return result
}

/// This version supports capture groups, whereas the one above doesn't
func split(by captureRegex: NSRegularExpression) -> [String] {
// Find the matching capture groups
let selfRange = NSRange(startIndex..<endIndex, in: self)
let matches = captureRegex.matches(in: self, options: [], range: selfRange)

if matches.first == nil { return [self] }

var result: [String] = []
var start = startIndex
for match in matches {
// Append prefix before matched separator
let prefixEnd = index(startIndex, offsetBy: match.range.lowerBound)
if start < prefixEnd {
result.append(String(self[start..<prefixEnd]))
}
start = index(startIndex, offsetBy: match.range.upperBound)

// Append separator, supporting capture groups
for r in (0..<match.numberOfRanges).reversed() {
let matchRange = match.range(at: r)
if let sepRange = Range(matchRange, in:self) {
result.append(String(self[sepRange]))
break
}
}
}

// Append remaining suffix
let beginningOfEnd = index(startIndex, offsetBy: matches.last!.range.upperBound)
if beginningOfEnd < endIndex {
result.append(String(self[beginningOfEnd...]))
}

return result
}
}

public enum SplitDelimiterBehavior {
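
The capture-group splitter added above keeps the matched separators as their own elements. A quick sketch of how it behaves (pattern and input mirror the new tests in AddedTokensTests.swift at the end of this diff):

import Foundation
import Tokenizers

let pattern = #"(<\|end\|>)\s*"#
let regex = try! NSRegularExpression(pattern: pattern, options: [])

// The capture group is preserved as a separate element, while the whitespace
// matched outside the group (the trailing `\s*`) is dropped from the output.
let parts = "start <|end|> for real".split(by: regex)
// ["start ", "<|end|>", "for real"]
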
76 changes: 71 additions & 5 deletions Sources/Tokenizers/Tokenizer.swift
@@ -141,6 +141,7 @@ public class PreTrainedTokenizer: Tokenizer {

private let addedTokens: Set<String>
private let specialTokens: [String: Int]
private let addedTokensRegex: NSRegularExpression?

private let preTokenizer: PreTokenizer?
private let normalizer: Normalizer?
@@ -161,6 +162,16 @@ public class PreTrainedTokenizer: Tokenizer {
specialTokens[content] = id
}
}

let addedTokensRegexString = (tokenizerData.addedTokens?.arrayValue ?? []).compactMap { addedToken in
guard let content = addedToken.content?.stringValue else { return nil }
let prefix = (addedToken.lstrip?.boolValue ?? false ? #"\s*"# : "")
let suffix = (addedToken.rstrip?.boolValue ?? false ? #"\s*"# : "")
let token = NSRegularExpression.escapedPattern(for: content)
return "\(prefix)(\(token))\(suffix)"
}.joined(separator: "|")
addedTokensRegex = try? NSRegularExpression(pattern: addedTokensRegexString, options: [])

// TODO: specialTokens are stored but never used
self.specialTokens = specialTokens
self.addedTokens = Set(addedTokens.keys)
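
The regex built above concatenates one alternative per added token, optionally absorbing surrounding whitespace when the token's `lstrip`/`rstrip` flags are set. A standalone sketch of the same construction, with a made-up token list standing in for the `Config` entries:

import Foundation

// Simplified stand-in for the added-token entries read from tokenizer.json.
struct AddedToken {
    let content: String
    let lstrip: Bool
    let rstrip: Bool
}

let tokens = [AddedToken(content: "<|end|>", lstrip: false, rstrip: true)]

let pattern = tokens.map { token -> String in
    let prefix = token.lstrip ? #"\s*"# : ""
    let suffix = token.rstrip ? #"\s*"# : ""
    let escaped = NSRegularExpression.escapedPattern(for: token.content)
    return "\(prefix)(\(escaped))\(suffix)"
}.joined(separator: "|")
// pattern == #"(<\|end\|>)\s*"#

let addedTokensRegex = try? NSRegularExpression(pattern: pattern, options: [])
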
@@ -174,9 +185,9 @@
model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
}

-    func preTokenize(_ text: String) -> [String] {
+    func preTokenize(_ text: String, options: PreTokenizerOptions) -> [String] {
        guard let preTokenizer = preTokenizer else { return [text] }
-        return preTokenizer(text: text)
+        return preTokenizer(text: text, options: options)
}

func normalize(_ text: String) -> String {
Expand Down Expand Up @@ -211,7 +222,17 @@ public class PreTrainedTokenizer: Tokenizer {
}

public func tokenize(text: String) -> [String] {
-        preTokenize(normalize(text)).flatMap { model($0) }
+        // Take care of special tokens first
+        let sections: [String]
+        if let regex = self.addedTokensRegex {
+            sections = text.split(by: regex)
+        } else {
+            sections = [text]
+        }
+        return sections.enumerated().map { section, x in
+            if addedTokens.contains(x) { return [x] }
+            return preTokenize(normalize(x), options: section == 0 ? [.firstSection] : []).flatMap { model($0) }
+        }.flatMap { $0 }
}

/// Main entry point
@@ -241,9 +262,32 @@ public class PreTrainedTokenizer: Tokenizer {

public struct AutoTokenizer {}

struct PreTrainedTokenizerClasses {
/// Class overrides for custom behaviour
/// Not to be confused with the TokenizerModel classes defined in TokenizerModel
static let tokenizerClasses: [String : PreTrainedTokenizer.Type] = [
"LlamaTokenizer": LlamaPreTrainedTokenizer.self
]
}

extension AutoTokenizer {
static func tokenizerClass(for tokenizerConfig: Config) -> PreTrainedTokenizer.Type {
guard let tokenizerClassName = tokenizerConfig.tokenizerClass?.stringValue else {
return PreTrainedTokenizer.self
}

// Some tokenizer_class entries use a Fast suffix
let tokenizerName = tokenizerClassName.replacingOccurrences(of: "Fast", with: "")
if let tokenizerClass = PreTrainedTokenizerClasses.tokenizerClasses[tokenizerName] {
return tokenizerClass
}

return PreTrainedTokenizer.self
}

public static func from(tokenizerConfig: Config, tokenizerData: Config) throws -> Tokenizer {
-        return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
+        let tokenizerClass = tokenizerClass(for: tokenizerConfig)
+        return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}

public static func from(
@@ -254,7 +298,7 @@ extension AutoTokenizer {
guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
let tokenizerData = try await config.tokenizerData

-        return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
+        return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}

public static func from(
@@ -281,3 +325,25 @@ class CodeLlamaTokenizer: BPETokenizer {}
class CohereTokenizer : BPETokenizer {}

class T5Tokenizer : UnigramTokenizer {}


// MARK: - PreTrainedTokenizer classes

let sentencePieceUnderline = "▁"

// See https://github.com/xenova/transformers.js/blob/1a9964fb09b8f54fcbeac46dc6aae8d76795809d/src/tokenizers.js#L3203 for these exceptions
class LlamaPreTrainedTokenizer: PreTrainedTokenizer {
let isLegacy: Bool

required init(tokenizerConfig: Config, tokenizerData: Config) throws {
isLegacy = tokenizerConfig.legacy?.boolValue ?? true
var configDictionary = tokenizerData.dictionary
if !isLegacy {
configDictionary.removeValue(forKey: "normalizer")
configDictionary["pre_tokenizer"] = ["type": "Metaspace", "replacement": sentencePieceUnderline, "add_prefix_space": true, "prepend_scheme": "first"]
}
let updatedData = Config(configDictionary)

try super.init(tokenizerConfig: tokenizerConfig, tokenizerData: updatedData)
}
}
67 changes: 67 additions & 0 deletions Tests/TokenizersTests/AddedTokensTests.swift
@@ -0,0 +1,67 @@
//
// AddedTokensTests.swift
//
//
// Created by Pedro Cuenca on 20240426.
//

import XCTest
import Tokenizers
import Hub

class AddedTokensTests: XCTestCase {
func testPhiAddedTokens() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Phi-3-mini-128k-instruct-4bit")
let inputIds = tokenizer("This is the <|end|>. My only friend, the <|end|>")
XCTAssertEqual(inputIds, [1, 910, 338, 278, 29871, 32007, 29889, 1619, 871, 5121, 29892, 278, 29871, 32007])

let decoded = tokenizer.decode(tokens: inputIds)
XCTAssertEqual(decoded, "<s> This is the <|end|>. My only friend, the <|end|>")
}

func testSplitWithCaptureGroups() {
let addedTokensRegexp = #"(<\|end\|>)\s*|(<\|raw\|>)\s*"#
let captureRegex = try! NSRegularExpression(pattern: addedTokensRegexp, options: [])

XCTAssertEqual(
"eating <|raw|> meat <|end|> That's all".split(by: captureRegex),
["eating ", "<|raw|>", "meat ", "<|end|>", "That's all"]
)

XCTAssertEqual(
"<|raw|>".split(by: captureRegex),
["<|raw|>"]
)

XCTAssertEqual(
"This string doesn't have those separators".split(by: captureRegex),
["This string doesn't have those separators"]
)

XCTAssertEqual(
"start <|end|>".split(by: captureRegex),
["start ", "<|end|>"]
)

XCTAssertEqual(
"start <|end|> ".split(by: captureRegex),
["start ", "<|end|>"]
)

XCTAssertEqual(
"start <|end|> ".split(by: captureRegex),
["start ", "<|end|>"]
)

XCTAssertEqual(
"start <|end|> for real".split(by: captureRegex),
["start ", "<|end|>", "for real"]
)

XCTAssertEqual(
"<|raw|><|end|>".split(by: captureRegex),
["<|raw|>", "<|end|>"]
)

}
}
