-
Notifications
You must be signed in to change notification settings - Fork 289
Add Qwen3-0.6B language model support to pythainlp.lm and improve type annotations #1217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 6 commits
183748c
9654117
0e4b368
e53c777
fe937d2
2018813
69fbb42
c816818
0d104dd
772a80b
91b7bda
e2a14de
b29e085
f55c1c6
c73b2a3
79a4204
e775b9e
0b031fc
94d4314
1339261
5b3416a
5ecd247
88762b0
1e3f0b6
28869e0
06b5d5f
2d06f22
aa8751b
3d550ce
0b52d0d
b9bf7fc
ff07202
df97e0c
ffb73ee
2a9f6e9
91303d3
432c2ab
f529a9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,9 +2,21 @@ | |
| # SPDX-FileType: SOURCE | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| __all__ = ["calculate_ngram_counts", "remove_repeated_ngrams"] | ||
| __all__ = ["calculate_ngram_counts", "remove_repeated_ngrams", "Qwen3"] | ||
bact marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| from pythainlp.lm.text_util import ( | ||
| calculate_ngram_counts, | ||
| remove_repeated_ngrams, | ||
| ) | ||
|
|
||
| try: | ||
| from pythainlp.lm.qwen3 import Qwen3 | ||
| except ImportError: | ||
| # If dependencies are not installed, make Qwen3 available but raise | ||
| # error when instantiated | ||
| class Qwen3: # type: ignore | ||
| def __init__(self): | ||
| raise ImportError( | ||
| "Qwen3 requires additional dependencies. " | ||
| "Install with: pip install pythainlp[qwen3]" | ||
| ) | ||
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,197 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-FileCopyrightText: 2016-2026 PyThaiNLP Project | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
bact marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-FileType: SOURCE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
bact marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class Qwen3: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Qwen3-0.6B language model for Thai text generation. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A small but capable language model from Alibaba Cloud's Qwen family, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| optimized for various NLP tasks including Thai language processing. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.device = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.torch_dtype = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model_path = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def load_model( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_path: str = "Qwen/Qwen3-0.6B", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device: str = "cuda", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch_dtype=torch.float16, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| low_cpu_mem_usage: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Load Qwen3 model. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| :param str model_path: model path or HuggingFace model ID | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| :param str device: device (cpu, cuda or other) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| :param torch_dtype: torch data type (e.g., torch.float16, torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| :param bool low_cpu_mem_usage: low cpu mem usage | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| :Example: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| :: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from pythainlp.lm import Qwen3 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model = Qwen3() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model.load_model(device="cpu", torch_dtype=torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.device = device | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.torch_dtype = torch_dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model_path = model_path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model = AutoModelForCausalLM.from_pretrained( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch_dtype=torch_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| low_cpu_mem_usage=low_cpu_mem_usage, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model.to(device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| self.model_path, | |
| torch_dtype=torch_dtype, | |
| low_cpu_mem_usage=low_cpu_mem_usage, | |
| ) | |
| self.model.to(device) | |
| try: | |
| self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) | |
| except OSError as exc: | |
| raise RuntimeError( | |
| f"Failed to load tokenizer from '{self.model_path}'. " | |
| "Check the model path or your network connection." | |
| ) from exc | |
| try: | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| self.model_path, | |
| torch_dtype=torch_dtype, | |
| low_cpu_mem_usage=low_cpu_mem_usage, | |
| ) | |
| except OSError as exc: | |
| raise RuntimeError( | |
| f"Failed to load model from '{self.model_path}'. " | |
| "This can happen due to an invalid model path, missing files, " | |
| "or insufficient disk space." | |
| ) from exc | |
| except RuntimeError as exc: | |
| raise RuntimeError( | |
| "Failed to load model weights. " | |
| "This can be caused by insufficient memory or an incompatible " | |
| "torch_dtype setting." | |
| ) from exc | |
| if isinstance(device, str) and device.startswith("cuda"): | |
| if not torch.cuda.is_available(): | |
| raise RuntimeError( | |
| "CUDA device requested but CUDA is not available. " | |
| "Check your PyTorch installation and GPU drivers, or use " | |
| "device='cpu' instead." | |
| ) | |
| try: | |
| self.model.to(device) | |
| except RuntimeError as exc: | |
| raise RuntimeError( | |
| f"Failed to move model to device '{device}'. " | |
| "Ensure the device exists and has enough memory, and that your " | |
| "PyTorch installation supports this device." | |
| ) from exc |
Copilot
AI
Feb 2, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The generate method lacks input validation for the text parameter. If an empty string or None is passed, it may cause unclear errors downstream. Consider adding validation to check that text is a non-empty string before processing.
Copilot
AI
Feb 5, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When do_sample=False (greedy decoding), the temperature, top_p, and top_k parameters are ignored by the transformers library. Consider adding validation to warn users or handle this case explicitly. The current implementation may mislead users who set do_sample=False but also provide temperature values expecting them to have an effect.
Copilot
AI
Feb 5, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The slicing operation output_ids[0][len(input_ids[0]):] assumes that output_ids[0] and input_ids[0] are present. While this should generally be safe given the model.generate call, it would be more defensive to check the shapes or handle potential IndexError. Consider adding a check or comment explaining why this is safe in this context.
| # Decode only the newly generated tokens | |
| generated_text = self.tokenizer.decode( | |
| output_ids[0][len(input_ids[0]) :], | |
| # Decode only the newly generated tokens. | |
| if ( | |
| output_ids.dim() == 2 | |
| and input_ids.dim() == 2 | |
| and output_ids.size(0) > 0 | |
| and input_ids.size(0) > 0 | |
| ): | |
| start_idx = input_ids.size(1) | |
| generated_ids = output_ids[0, start_idx:] | |
| else: | |
| raise RuntimeError( | |
| "Unexpected tensor shape from model.generate(); " | |
| "expected 2D tensors with non-empty batch dimension." | |
| ) | |
| generated_text = self.tokenizer.decode( | |
| generated_ids, |
Copilot
AI
Feb 2, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The chat method lacks input validation for the messages parameter. If an empty list is passed, the method will proceed without error but may produce unexpected behavior. Consider adding a check to validate that messages is not empty and contains properly formatted message dictionaries with 'role' and 'content' keys.
Outdated
Copilot
AI
Feb 2, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The fallback chat template formatting (lines 172-177) could produce ambiguous output if message content contains newline characters. Consider either sanitizing the content to remove/escape newlines, or using a more robust delimiter that won't be present in natural text.
| role = msg.get("role", "user") | |
| content = msg.get("content", "") | |
| role = str(msg.get("role", "user")).replace("\n", " ") | |
| content = str(msg.get("content", "")).replace("\n", "\\n") |
bact marked this conversation as resolved.
Show resolved
Hide resolved
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| # SPDX-FileCopyrightText: 2016-2026 PyThaiNLP Project | ||
| # SPDX-FileType: SOURCE | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import unittest | ||
|
|
||
| from pythainlp.lm import Qwen3 | ||
|
|
||
|
|
||
| class LMTestCaseX(unittest.TestCase): | ||
|
||
| def test_qwen3_initialization(self): | ||
| # Test that Qwen3 can be instantiated | ||
| try: | ||
| model = Qwen3() | ||
| self.assertIsNotNone(model) | ||
| self.assertIsNone(model.model) | ||
| self.assertIsNone(model.tokenizer) | ||
| except ImportError: | ||
| # Skip if dependencies not installed | ||
| self.skipTest("Qwen3 dependencies not installed") | ||
|
|
||
| def test_qwen3_generate_without_load(self): | ||
| # Test that generate raises error when model is not loaded | ||
| try: | ||
| model = Qwen3() | ||
| with self.assertRaises(RuntimeError): | ||
| model.generate("test") | ||
| except ImportError: | ||
| # Skip if dependencies not installed | ||
| self.skipTest("Qwen3 dependencies not installed") | ||
|
|
||
| def test_qwen3_chat_without_load(self): | ||
| # Test that chat raises error when model is not loaded | ||
| try: | ||
| model = Qwen3() | ||
| with self.assertRaises(RuntimeError): | ||
| model.chat([{"role": "user", "content": "test"}]) | ||
| except ImportError: | ||
| # Skip if dependencies not installed | ||
| self.skipTest("Qwen3 dependencies not installed") | ||
Uh oh!
There was an error while loading. Please reload this page.