diff --git a/l2p/llm_builder.py b/l2p/llm_builder.py
index 8685d52..c30d0fa 100644
--- a/l2p/llm_builder.py
+++ b/l2p/llm_builder.py
@@ -269,6 +269,8 @@ def __init__(self, model_path: str, max_tokens=4e3, temperature=0.01, top_p=0.9)
         self.top_p = top_p
         self.in_tokens = 0
         self.out_tokens = 0
+
+        self._load_transformers()

     def _load_transformers(self):
@@ -300,11 +302,12 @@ def _load_transformers(self):
        )

        try:
+           dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            # Check if the model_path is valid by trying to load it
            self.model = transformers.pipeline(
                "text-generation",
                model=self.model_path,
-               model_kwargs={"torch_dtype": torch.bfloat16},
+               model_kwargs={"torch_dtype": dtype},
                device_map="auto",
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
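
The diff has two effects: the model now loads eagerly when the object is constructed (`_load_transformers()` is called from `__init__`), and the hard-coded `torch.bfloat16` is replaced by a runtime dtype choice, so loading no longer assumes bfloat16-capable hardware: float16 when a CUDA device is available, float32 otherwise. A minimal standalone sketch of that dtype-selection pattern, assuming torch and transformers are installed; the model name below is illustrative, not the one l2p uses:

    import torch
    import transformers
    from transformers import AutoTokenizer

    model_path = "meta-llama/Llama-3.2-1B"  # hypothetical example model

    # bfloat16 is not supported on all hardware; pick half precision only
    # when a CUDA device is present, and fall back to float32 on CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    pipe = transformers.pipeline(
        "text-generation",
        model=model_path,
        model_kwargs={"torch_dtype": dtype},
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    print(pipe("Hello, world", max_new_tokens=8)[0]["generated_text"])

The design choice here is a trade-off: float16 on GPU halves memory relative to float32 at a small cost in numeric range, while the CPU fallback keeps the constructor from raising on machines without CUDA.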