-
Notifications
You must be signed in to change notification settings - Fork 289
Add Qwen3-0.6B language model support to pythainlp.lm and improve type annotations #1217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 36 commits
183748c
9654117
0e4b368
e53c777
fe937d2
2018813
69fbb42
c816818
0d104dd
772a80b
91b7bda
e2a14de
b29e085
f55c1c6
c73b2a3
79a4204
e775b9e
0b031fc
94d4314
1339261
5b3416a
5ecd247
88762b0
1e3f0b6
28869e0
06b5d5f
2d06f22
aa8751b
3d550ce
0b52d0d
b9bf7fc
ff07202
df97e0c
ffb73ee
2a9f6e9
91303d3
432c2ab
f529a9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,260 @@ | ||||||||||||||||||||||||||||||||||||||||||
| # SPDX-FileCopyrightText: 2016-2026 PyThaiNLP Project | ||||||||||||||||||||||||||||||||||||||||||
bact marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||
| # SPDX-FileType: SOURCE | ||||||||||||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||
bact marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING, Any, Optional | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if TYPE_CHECKING: | ||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||
| from transformers import PreTrainedModel, PreTrainedTokenizerBase | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| class Qwen3: | ||||||||||||||||||||||||||||||||||||||||||
| """Qwen3-0.6B language model for Thai text generation. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| A small but capable language model from Alibaba Cloud's Qwen family, | ||||||||||||||||||||||||||||||||||||||||||
| optimized for various NLP tasks including Thai language processing. | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def __init__(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||
| self.model: Optional["PreTrainedModel"] = None | ||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer: Optional["PreTrainedTokenizerBase"] = None | ||||||||||||||||||||||||||||||||||||||||||
| self.device: Optional[str] = None | ||||||||||||||||||||||||||||||||||||||||||
| self.torch_dtype: Optional["torch.dtype"] = None | ||||||||||||||||||||||||||||||||||||||||||
| self.model_path: Optional[str] = None | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def load_model( | ||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||
| model_path: str = "Qwen/Qwen3-0.6B", | ||||||||||||||||||||||||||||||||||||||||||
| device: str = "cuda", | ||||||||||||||||||||||||||||||||||||||||||
| torch_dtype: Optional["torch.dtype"] = None, | ||||||||||||||||||||||||||||||||||||||||||
| low_cpu_mem_usage: bool = True, | ||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||
| """Load Qwen3 model. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| :param str model_path: model path or HuggingFace model ID | ||||||||||||||||||||||||||||||||||||||||||
| :param str device: device (cpu, cuda or other) | ||||||||||||||||||||||||||||||||||||||||||
| :param Optional[torch.dtype] torch_dtype: torch data type (e.g., torch.float16, torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||
| :param bool low_cpu_mem_usage: low cpu mem usage | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| :Example: | ||||||||||||||||||||||||||||||||||||||||||
| :: | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| from pythainlp.lm import Qwen3 | ||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| model = Qwen3() | ||||||||||||||||||||||||||||||||||||||||||
| model.load_model(device="cpu", torch_dtype=torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||||||||||||||||||||||||||||||||||||||||||
bact marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Set default torch_dtype if not provided | ||||||||||||||||||||||||||||||||||||||||||
| if torch_dtype is None: | ||||||||||||||||||||||||||||||||||||||||||
| torch_dtype = torch.float16 | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Check CUDA availability early before loading model | ||||||||||||||||||||||||||||||||||||||||||
| if device.startswith("cuda"): | ||||||||||||||||||||||||||||||||||||||||||
| if not torch.cuda.is_available(): | ||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||||||||||||||||||||||||
| "CUDA device requested but CUDA is not available. " | ||||||||||||||||||||||||||||||||||||||||||
| "Check your PyTorch installation and GPU drivers, or use " | ||||||||||||||||||||||||||||||||||||||||||
| "device='cpu' instead." | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| self.device = device | ||||||||||||||||||||||||||||||||||||||||||
| self.torch_dtype = torch_dtype | ||||||||||||||||||||||||||||||||||||||||||
| self.model_path = model_path | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) | ||||||||||||||||||||||||||||||||||||||||||
| except OSError as exc: | ||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||||||||||||||||||||||||
| f"Failed to load tokenizer from '{self.model_path}'. " | ||||||||||||||||||||||||||||||||||||||||||
| "Check the model path or your network connection." | ||||||||||||||||||||||||||||||||||||||||||
| ) from exc | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||
| self.model = AutoModelForCausalLM.from_pretrained( | ||||||||||||||||||||||||||||||||||||||||||
| self.model_path, | ||||||||||||||||||||||||||||||||||||||||||
| device_map=device, | ||||||||||||||||||||||||||||||||||||||||||
| torch_dtype=torch_dtype, | ||||||||||||||||||||||||||||||||||||||||||
| low_cpu_mem_usage=low_cpu_mem_usage, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| except OSError as exc: | ||||||||||||||||||||||||||||||||||||||||||
| # Clean up tokenizer on failure | ||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer = None | ||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||||||||||||||||||||||||
| f"Failed to load model from '{self.model_path}'. " | ||||||||||||||||||||||||||||||||||||||||||
| "This can happen due to an invalid model path, missing files, " | ||||||||||||||||||||||||||||||||||||||||||
| "or insufficient disk space." | ||||||||||||||||||||||||||||||||||||||||||
| ) from exc | ||||||||||||||||||||||||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||||||||||||||||||||||||
| # Clean up tokenizer on failure | ||||||||||||||||||||||||||||||||||||||||||
| self.tokenizer = None | ||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||||||||||||||||||||||||
| f"Failed to load model weights: {exc}. " | ||||||||||||||||||||||||||||||||||||||||||
| "This can be caused by insufficient memory, an incompatible " | ||||||||||||||||||||||||||||||||||||||||||
| "torch_dtype setting, or other configuration issues." | ||||||||||||||||||||||||||||||||||||||||||
| ) from exc | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def generate( | ||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||
| text: str, | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
| max_new_tokens: int = 512, | ||||||||||||||||||||||||||||||||||||||||||
| temperature: float = 0.7, | ||||||||||||||||||||||||||||||||||||||||||
| top_p: float = 0.9, | ||||||||||||||||||||||||||||||||||||||||||
| top_k: int = 50, | ||||||||||||||||||||||||||||||||||||||||||
| do_sample: bool = True, | ||||||||||||||||||||||||||||||||||||||||||
| skip_special_tokens: bool = True, | ||||||||||||||||||||||||||||||||||||||||||
| ) -> str: | ||||||||||||||||||||||||||||||||||||||||||
| """Generate text from a prompt. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| :param str text: input text prompt | ||||||||||||||||||||||||||||||||||||||||||
| :param int max_new_tokens: maximum number of new tokens to generate | ||||||||||||||||||||||||||||||||||||||||||
| :param float temperature: temperature for sampling (higher = more random) | ||||||||||||||||||||||||||||||||||||||||||
| :param float top_p: top p for nucleus sampling | ||||||||||||||||||||||||||||||||||||||||||
| :param int top_k: top k for top-k sampling | ||||||||||||||||||||||||||||||||||||||||||
| :param bool do_sample: whether to use sampling or greedy decoding | ||||||||||||||||||||||||||||||||||||||||||
| :param bool skip_special_tokens: skip special tokens in output | ||||||||||||||||||||||||||||||||||||||||||
| :return: generated text | ||||||||||||||||||||||||||||||||||||||||||
| :rtype: str | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| :Example: | ||||||||||||||||||||||||||||||||||||||||||
| :: | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| from pythainlp.lm import Qwen3 | ||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| model = Qwen3() | ||||||||||||||||||||||||||||||||||||||||||
| model.load_model(device="cpu", torch_dtype=torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| result = model.generate("สวัสดี") | ||||||||||||||||||||||||||||||||||||||||||
| print(result) | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| if self.model is None or self.tokenizer is None or self.device is None: | ||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||||||||||||||||||||||||
| "Model not loaded. Please call load_model() first." | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if not text or not isinstance(text, str): | ||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||
| "text parameter must be a non-empty string." | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| inputs = self.tokenizer(text, return_tensors="pt") | ||||||||||||||||||||||||||||||||||||||||||
| input_ids = inputs["input_ids"].to(self.device) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Note: When do_sample=False (greedy decoding), temperature, top_p, | ||||||||||||||||||||||||||||||||||||||||||
| # and top_k parameters are ignored by the transformers library | ||||||||||||||||||||||||||||||||||||||||||
| with torch.inference_mode(): | ||||||||||||||||||||||||||||||||||||||||||
| output_ids = self.model.generate( | ||||||||||||||||||||||||||||||||||||||||||
| input_ids, | ||||||||||||||||||||||||||||||||||||||||||
| max_new_tokens=max_new_tokens, | ||||||||||||||||||||||||||||||||||||||||||
| temperature=temperature, | ||||||||||||||||||||||||||||||||||||||||||
| top_p=top_p, | ||||||||||||||||||||||||||||||||||||||||||
| top_k=top_k, | ||||||||||||||||||||||||||||||||||||||||||
| do_sample=do_sample, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+167
to
+174
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Decode only the newly generated tokens | ||||||||||||||||||||||||||||||||||||||||||
| # output_ids and input_ids are guaranteed to be 2D tensors with | ||||||||||||||||||||||||||||||||||||||||||
| # batch size 1 from the tokenizer call above | ||||||||||||||||||||||||||||||||||||||||||
| generated_text = self.tokenizer.decode( | ||||||||||||||||||||||||||||||||||||||||||
| output_ids[0][len(input_ids[0]) :], | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
176
to
180
|
||||||||||||||||||||||||||||||||||||||||||
| # Decode only the newly generated tokens | |
| generated_text = self.tokenizer.decode( | |
| output_ids[0][len(input_ids[0]) :], | |
| # Decode only the newly generated tokens. | |
| if ( | |
| output_ids.dim() == 2 | |
| and input_ids.dim() == 2 | |
| and output_ids.size(0) > 0 | |
| and input_ids.size(0) > 0 | |
| ): | |
| start_idx = input_ids.size(1) | |
| generated_ids = output_ids[0, start_idx:] | |
| else: | |
| raise RuntimeError( | |
| "Unexpected tensor shape from model.generate(); " | |
| "expected 2D tensors with non-empty batch dimension." | |
| ) | |
| generated_text = self.tokenizer.decode( | |
| generated_ids, |
Copilot
AI
Feb 2, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The chat method lacks input validation for the messages parameter. If an empty list is passed, the method will proceed without error but may produce unexpected behavior. Consider adding a check to validate that messages is not empty and contains properly formatted message dictionaries with 'role' and 'content' keys.
Uh oh!
There was an error while loading. Please reload this page.