Merge branch 'main' into main
greenrazer authored Jan 29, 2025
2 parents 307a26c + ff81749 commit 7897a7e
Showing 5 changed files with 186 additions and 20 deletions.
4 changes: 2 additions & 2 deletions Package.swift
@@ -14,8 +14,8 @@ let package = Package(
.executable(name: "hub-cli", targets: ["HubCLI"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0"),
.package(url: "https://github.com/johnmai-dev/Jinja", from: "1.1.0")
.package(url: "https://github.com/apple/swift-argument-parser.git", .upToNextMinor(from: "1.4.0")),
.package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.1.0"))
],
targets: [
.executableTarget(
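
For context on the pinning above: a `from:` requirement lets SwiftPM resolve any version up to the next major release (1.4.0 ..< 2.0.0), while `.upToNextMinor(from:)` accepts patch updates only (1.4.0 ..< 1.5.0), so a new minor release of swift-argument-parser or Jinja cannot be pulled in without an explicit manifest bump. A minimal sketch of a hypothetical consumer manifest illustrating the two requirement styles (the `RequirementDemo` name is a placeholder; the URLs and product names are the real ones):

// swift-tools-version: 5.8
// Hypothetical manifest, for illustration only.
import PackageDescription

let package = Package(
    name: "RequirementDemo",
    dependencies: [
        // `from:` would allow anything below the next major release: 1.4.0 ..< 2.0.0.
        // .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.4.0"),

        // `.upToNextMinor(from:)` allows patch updates only: 1.4.0 ..< 1.5.0.
        .package(url: "https://github.com/apple/swift-argument-parser.git", .upToNextMinor(from: "1.4.0")),
        .package(url: "https://github.com/johnmai-dev/Jinja", .upToNextMinor(from: "1.1.0")), // 1.1.0 ..< 1.2.0
    ],
    targets: [
        .executableTarget(
            name: "RequirementDemo",
            dependencies: [
                .product(name: "ArgumentParser", package: "swift-argument-parser"),
                .product(name: "Jinja", package: "Jinja"),
            ]
        )
    ]
)
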
8 changes: 6 additions & 2 deletions Sources/Models/LanguageModel.swift
@@ -190,7 +190,9 @@ public extension LanguageModel {
var tokenizer: Tokenizer {
get async throws {
guard _tokenizer == nil else { return _tokenizer! }
guard let tokenizerConfig = try await tokenizerConfig else { throw "Cannot retrieve Tokenizer configuration" }
guard let tokenizerConfig = try await tokenizerConfig else {
throw TokenizerError.tokenizerConfigNotFound
}
let tokenizerData = try await tokenizerData
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
return _tokenizer!
@@ -212,4 +214,6 @@ extension LanguageModel: TextGenerationModel {
}
}

extension String: Error {}
public enum TokenizerError: Error {
case tokenizerConfigNotFound
}
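
With the ad-hoc `extension String: Error {}` conformance removed, callers can match on a concrete error case instead of a raw string. A minimal sketch of a call site under that change (the helper function is hypothetical, and it assumes the tokenizer's existing `encode(text:)` API):

import Models
import Tokenizers

// Hypothetical helper: encode a prompt with a model's tokenizer,
// handling the typed error introduced above.
func encodePrompt(_ prompt: String, with model: LanguageModel) async -> [Int]? {
    do {
        let tokenizer = try await model.tokenizer
        return tokenizer.encode(text: prompt)
    } catch TokenizerError.tokenizerConfigNotFound {
        // Replaces matching on the old `throw "Cannot retrieve Tokenizer configuration"` string.
        print("No tokenizer configuration found for this model.")
        return nil
    } catch {
        print("Tokenization failed: \(error)")
        return nil
    }
}
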
102 changes: 88 additions & 14 deletions Sources/Tokenizers/Tokenizer.swift
@@ -9,6 +9,9 @@ import Hub
import Foundation
import Jinja

public typealias Message = [String: Any]
public typealias ToolSpec = [String: Any]

enum TokenizerError: Error {
case missingConfig
case missingTokenizerClassInConfig
@@ -142,23 +145,57 @@ public protocol Tokenizer {
var unknownTokenId: Int? { get }

/// The appropriate chat template is selected from the tokenizer config
func applyChatTemplate(messages: [[String: String]]) throws -> [Int]
func applyChatTemplate(messages: [Message]) throws -> [Int]

/// The appropriate chat template is selected from the tokenizer config
func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int]

/// The chat template is provided as a string literal or specified by name
func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int]
func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int]

/// The chat template is provided as a string literal
func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int]
func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int]

func applyChatTemplate(
messages: [[String: String]],
messages: [Message],
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
chatTemplate: ChatTemplateArgument?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?,
tools: [[String: Any]]?
tools: [ToolSpec]?
) throws -> [Int]

func applyChatTemplate(
messages: [Message],
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
chatTemplate: ChatTemplateArgument?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?,
tools: [ToolSpec]?,
additionalContext: [String: Any]?
) throws -> [Int]
}

extension Tokenizer {
/// Call previous signature for backwards compatibility
func applyChatTemplate(
messages: [Message],
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
chatTemplate: ChatTemplateArgument?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?,
tools: [ToolSpec]?,
additionalContext: [String: Any]?
) throws -> [Int] {
if additionalContext == nil {
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools)
} else {
throw TokenizerError.chatTemplate("Not implemented")
}
}
}

public extension Tokenizer {
@@ -359,20 +396,46 @@ public class PreTrainedTokenizer: Tokenizer {
model.convertIdToToken(id)
}

public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
public func applyChatTemplate(messages: [Message]) throws -> [Int] {
try applyChatTemplate(messages: messages, addGenerationPrompt: true)
}

public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
public func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] {
try applyChatTemplate(messages: messages, addGenerationPrompt: true, tools: tools)
}

public func applyChatTemplate(messages: [Message], tools: [ToolSpec], additionalContext: [String: Any]) throws
-> [Int]
{
try applyChatTemplate(
messages: messages,
addGenerationPrompt: true,
tools: tools,
additionalContext: additionalContext
)
}

public func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true)
}

public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
public func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true)
}

public func applyChatTemplate(
messages: [[String: String]],
messages: [Message],
chatTemplate: ChatTemplateArgument? = nil,
addGenerationPrompt: Bool = false,
truncation: Bool = false,
maxLength: Int? = nil,
tools: [ToolSpec]? = nil
) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength, tools: tools, additionalContext: nil)
}

public func applyChatTemplate(
messages: [Message],
chatTemplate: ChatTemplateArgument? = nil,
addGenerationPrompt: Bool = false,
truncation: Bool = false,
@@ -382,8 +445,8 @@ public class PreTrainedTokenizer: Tokenizer {
/// giving the name, description and argument types for the tool. See the
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
/// for more information.
/// Note: tool calling is not supported yet, it will be available in a future update.
tools: [[String: Any]]? = nil
tools: [ToolSpec]? = nil,
additionalContext: [String: Any]? = nil
) throws -> [Int] {
var selectedChatTemplate: String?
if let chatTemplate, case .literal(let template) = chatTemplate {
@@ -425,10 +488,21 @@ public class PreTrainedTokenizer: Tokenizer {
let template = try Template(selectedChatTemplate)
var context: [String: Any] = [
"messages": messages,
"add_generation_prompt": addGenerationPrompt
// TODO: Add `tools` entry when support is added in Jinja
// "tools": tools
"add_generation_prompt": addGenerationPrompt,
]
if let tools {
context["tools"] = tools
}
if let additionalContext {
/*
Additional keys and values to be added to the context provided to the prompt templating engine.
For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
*/
for (key, value) in additionalContext {
context[key] = value
}
}

// TODO: maybe keep NSString here
for (key, value) in tokenizerConfig.dictionary as [String : Any] {
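
Taken together, the `Message`/`ToolSpec` aliases, the new protocol requirements, and the `PreTrainedTokenizer` overrides let an app pass tool specs and extra template variables through to the Jinja context. A minimal usage sketch, assuming Hub access at runtime; the model id is the one used in the test below, and the `tools_in_user_message` key only matters for templates that read it (e.g. Llama 3.1/3.2, per the comment above):

import Tokenizers

// Sketch of the new tool-calling entry points; error handling is omitted for brevity.
func encodeToolPrompt() async throws -> [Int] {
    let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2.5-7B-Instruct-4bit")

    let messages: [Message] = [
        ["role": "user", "content": "What is the weather in Paris today?"]
    ]

    let getCurrentWeather: ToolSpec = [
        "type": "function",
        "function": [
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": [
                "type": "object",
                "properties": [
                    "location": [
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    ]
                ],
                "required": ["location"]
            ]
        ]
    ]

    // Tools only: the new `applyChatTemplate(messages:tools:)` requirement.
    let withTools = try tokenizer.applyChatTemplate(messages: messages, tools: [getCurrentWeather])
    print("tools-only prompt: \(withTools.count) tokens")

    // Tools plus additional context: the full requirement. PreTrainedTokenizer merges
    // these keys into the Jinja context; the default protocol extension throws instead.
    return try tokenizer.applyChatTemplate(
        messages: messages,
        chatTemplate: nil,
        addGenerationPrompt: true,
        truncation: false,
        maxLength: nil,
        tools: [getCurrentWeather],
        additionalContext: ["tools_in_user_message": false]
    )
}
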
90 changes: 89 additions & 1 deletion Tests/TokenizersTests/ChatTemplateTests.swift
@@ -80,5 +80,93 @@ class ChatTemplateTests: XCTestCase {
XCTAssertEqual(decoded, decodedTarget)
}

// TODO: Add tests for tool use template
func testQwen2_5WithTools() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Qwen2.5-7B-Instruct-4bit")

let weatherQueryMessages: [[String: String]] = [
[
"role": "user",
"content": "What is the weather in Paris today?",
]
]

let getCurrentWeatherToolSpec: [String: Any] = [
"type": "function",
"function": [
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": [
"type": "object",
"properties": [
"location": [
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
],
"unit": [
"type": "string",
"enum": ["celsius", "fahrenheit"]
]
],
"required": ["location"]
]
]
]

let encoded = try tokenizer.applyChatTemplate(messages: weatherQueryMessages, tools: [getCurrentWeatherToolSpec])
let decoded = tokenizer.decode(tokens: encoded)

func assertDictsAreEqual(_ actual: [String: Any], _ expected: [String: Any]) {
for (key, value) in actual {
if let nestedDict = value as? [String: Any], let nestedDict2 = expected[key] as? [String: Any] {
assertDictsAreEqual(nestedDict, nestedDict2)
} else if let arrayValue = value as? [String] {
let expectedArrayValue = expected[key] as? [String]
XCTAssertNotNil(expectedArrayValue)
XCTAssertEqual(Set(arrayValue), Set(expectedArrayValue!))
} else {
XCTAssertEqual(value as? String, expected[key] as? String)
}
}
}

if let startRange = decoded.range(of: "<tools>\n"),
let endRange = decoded.range(of: "\n</tools>", range: startRange.upperBound..<decoded.endIndex) {
let toolsSection = String(decoded[startRange.upperBound..<endRange.lowerBound])
if let toolsDict = try? JSONSerialization.jsonObject(with: toolsSection.data(using: .utf8)!) as? [String : Any] {
assertDictsAreEqual(toolsDict, getCurrentWeatherToolSpec)
} else {
XCTFail("Failed to decode tools section")
}
} else {
XCTFail("Failed to find tools section")
}

let expectedPromptStart = """
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
"""

let expectedPromptEnd = """
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call><|im_end|>
<|im_start|>user
What is the weather in Paris today?<|im_end|>
<|im_start|>assistant
"""

XCTAssertTrue(decoded.hasPrefix(expectedPromptStart), "Prompt should start with expected system message")
XCTAssertTrue(decoded.hasSuffix(expectedPromptEnd), "Prompt should end with expected format")
}
}
2 changes: 1 addition & 1 deletion Tests/TokenizersTests/TokenizerTests.swift
@@ -260,7 +260,7 @@ class TokenizerTester {
do {
guard let tokenizerConfig = try await configuration!.tokenizerConfig else {
XCTFail("Cannot retrieve Tokenizer configuration")
return nil
return nil
}
let tokenizerData = try await configuration!.tokenizerData
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
