Skip to content

Commit

Permalink
Chat templates by @maiqingqiang (#104)
Browse files Browse the repository at this point in the history
* add jinja package

* support chat template

* Support `addSpecialTokens`.

* Remove padding for now

We need to get back to this to support it consistently.

---------

Co-authored-by: John Mai <[email protected]>
  • Loading branch information
pcuenca and johnmai-dev authored Sep 1, 2024
1 parent c088078 commit 5d89b5d
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ DerivedData/
.swiftpm/config/registries.json
.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
.netrc
.idea
5 changes: 3 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ let package = Package(
.executable(name: "hub-cli", targets: ["HubCLI"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0")
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0"),
.package(url: "https://github.com/maiqingqiang/Jinja", branch: "main")
],
targets: [
.executableTarget(
Expand All @@ -22,7 +23,7 @@ let package = Package(
.product(name: "ArgumentParser", package: "swift-argument-parser")]),
.executableTarget(name: "HubCLI", dependencies: ["Hub", .product(name: "ArgumentParser", package: "swift-argument-parser")]),
.target(name: "Hub", resources: [.process("FallbackConfigs")]),
.target(name: "Tokenizers", dependencies: ["Hub"]),
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "Jinja")]),
.target(name: "TensorUtils"),
.target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]),
.target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]),
Expand Down
2 changes: 1 addition & 1 deletion Sources/Tokenizers/BPETokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class BPETokenizer: PreTrainedTokenizerModel {
self.unknownToken = nil
self.unknownTokenId = nil
}

eosToken = tokenizerConfig.eosToken?.stringValue
eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken! as NSString]

Expand Down
22 changes: 12 additions & 10 deletions Sources/Tokenizers/PostProcessor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import Foundation
import Hub

public protocol PostProcessor {
func postProcess(tokens: [String], tokensPair: [String]?) -> [String]
func callAsFunction(tokens: [String], tokensPair: [String]?) -> [String]
func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool) -> [String]
func callAsFunction(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool) -> [String]

init(config: Config)
}

extension PostProcessor {
func callAsFunction(tokens: [String], tokensPair: [String]? = nil) -> [String] {
return postProcess(tokens: tokens, tokensPair: tokensPair)
func callAsFunction(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] {
return postProcess(tokens: tokens, tokensPair: tokensPair, addSpecialTokens: addSpecialTokens)
}
}

Expand Down Expand Up @@ -53,13 +53,15 @@ class TemplateProcessing: PostProcessor {
self.pair = pair
}

func postProcess(tokens: [String], tokensPair: [String]? = nil) -> [String] {
func postProcess(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] {
let config = tokensPair == nil ? single : pair

var toReturn: [String] = []
for item in config {
if let specialToken = item.SpecialToken {
toReturn.append(specialToken.id!.stringValue!)
if addSpecialTokens {
toReturn.append(specialToken.id!.stringValue!)
}
} else if let sequence = item.Sequence {
if sequence.id?.stringValue == "A" {
toReturn += tokens
Expand All @@ -74,7 +76,7 @@ class TemplateProcessing: PostProcessor {

class ByteLevelPostProcessor: PostProcessor {
required public init(config: Config) {}
func postProcess(tokens: [String], tokensPair: [String]? = nil) -> [String] { tokens }
func postProcess(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] { tokens }
}

class RobertaProcessing: PostProcessor {
Expand All @@ -94,7 +96,7 @@ class RobertaProcessing: PostProcessor {
self.addPrefixSpace = config.addPrefixSpace?.boolValue ?? true
}

func postProcess(tokens: [String], tokensPair: [String]?) -> [String] {
func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool = true) -> [String] {
var outTokens = tokens
var tokensPair = tokensPair
if trimOffset {
Expand Down
82 changes: 75 additions & 7 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import Hub
import Foundation
import Jinja

enum TokenizerError : Error {
case missingConfig
Expand Down Expand Up @@ -98,7 +99,8 @@ public protocol Tokenizer {

/// Main entry point
func encode(text: String) -> [Int]
func callAsFunction(_ text: String) -> [Int]
func encode(text: String, addSpecialTokens: Bool) -> [Int]
func callAsFunction(_ text: String, addSpecialTokens: Bool) -> [Int]

/// Decode
func decode(tokens: [Int]) -> String
Expand All @@ -115,11 +117,21 @@ public protocol Tokenizer {
var eosTokenId: Int? { get }
var unknownToken: String? { get }
var unknownTokenId: Int? { get }

func applyChatTemplate(messages: [[String: String]]) throws -> [Int]

func applyChatTemplate(
messages: [[String: String]],
chatTemplate: String?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?
) throws -> [Int]
}

public extension Tokenizer {
func callAsFunction(_ text: String) -> [Int] {
encode(text: text)
func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] {
encode(text: text, addSpecialTokens: addSpecialTokens)
}

func convertTokensToIds(_ tokens: [String]) -> [Int?] {
Expand All @@ -131,6 +143,17 @@ public extension Tokenizer {
}
}

/// Keys of special-token entries that may appear in the tokenizer configuration.
/// Matching entries are copied from `tokenizerConfig` into the chat-template
/// rendering context (see `applyChatTemplate`), so templates can reference
/// values such as `bos_token` or `eos_token` directly.
let specialTokenAttributes: [String] = [
    "bos_token",
    "eos_token",
    "unk_token",
    "sep_token",
    "pad_token",
    "cls_token",
    "mask_token",
    "additional_special_tokens"
]

public class PreTrainedTokenizer: Tokenizer {
let model: TokenizingModel

Expand All @@ -150,8 +173,11 @@ public class PreTrainedTokenizer: Tokenizer {
private let normalizer: Normalizer?
private let postProcessor: PostProcessor?
private let decoder: Decoder?
private let tokenizerConfig: Config

private let cleanUpTokenizationSpaces: Bool

private let defaultChatTemplate: String = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

required public init(tokenizerConfig: Config, tokenizerData: Config) throws {
var addedTokens: [String : Int] = [:]
Expand Down Expand Up @@ -195,7 +221,8 @@ public class PreTrainedTokenizer: Tokenizer {
self.postProcessor = PostProcessorFactory.fromConfig(config: tokenizerData.postProcessor)
self.decoder = DecoderFactory.fromConfig(config: tokenizerData.decoder)
self.cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? true

self.tokenizerConfig = tokenizerConfig

model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
}

Expand All @@ -209,9 +236,9 @@ public class PreTrainedTokenizer: Tokenizer {
return normalizer(text: text)
}

func postProcess(_ tokens: [String]) -> [String] {
func postProcess(_ tokens: [String], addSpecialTokens: Bool = true) -> [String] {
guard let postProcessor = postProcessor else { return tokens }
return postProcessor(tokens: tokens)
return postProcessor(tokens: tokens, addSpecialTokens: addSpecialTokens)
}

func decodeTokens(_ tokens: [String]) -> [String] {
Expand Down Expand Up @@ -265,8 +292,12 @@ public class PreTrainedTokenizer: Tokenizer {
}

/// Main entry point
public func encode(text: String, addSpecialTokens: Bool = true) -> [Int] {
return postProcess(tokenize(text: text), addSpecialTokens: addSpecialTokens).map { model.convertTokenToId($0)! }
}

public func encode(text: String) -> [Int] {
return postProcess(tokenize(text: text)).map { model.convertTokenToId($0)! }
return encode(text: text, addSpecialTokens: true)
}

/// Decode
Expand All @@ -285,6 +316,43 @@ public class PreTrainedTokenizer: Tokenizer {
    /// Returns the string form of a single token id, delegating to the underlying model.
    public func convertIdToToken(_ id: Int) -> String? {
        model.convertIdToToken(id)
    }

public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: nil, addGenerationPrompt: true, maxLength: nil)
}

public func applyChatTemplate(
messages: [[String: String]],
chatTemplate: String?,
addGenerationPrompt: Bool = false,
truncation: Bool = false,
maxLength: Int?
) throws -> [Int] {
let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate)
var context: [String: Any] = [
"messages": messages,
"add_generation_prompt": addGenerationPrompt
]

// TODO: maybe keep NSString here
for (key, value) in tokenizerConfig.dictionary as [String : Any] {
if specialTokenAttributes.contains(key), !(value is NSNull) {
context[key] = value
}
}

let rendered = try template.render(context)
var encodedTokens = encode(text: rendered, addSpecialTokens: false)
var maxLength = maxLength ?? encodedTokens.count
maxLength = min(maxLength, tokenizerConfig.modelMaxLength?.intValue ?? maxLength)
if encodedTokens.count > maxLength {
if truncation {
encodedTokens = Array(encodedTokens.prefix(maxLength))
}
}

return encodedTokens
}
}

// MARK: - Building
Expand Down

0 comments on commit 5d89b5d

Please sign in to comment.