From ddaf95370c9fd10aed12b3fb7c0786c740de75a0 Mon Sep 17 00:00:00 2001 From: Bulat Date: Mon, 20 Feb 2023 00:28:31 +0300 Subject: [PATCH 1/2] Update __init__.py add variable cache_dir to download models in specified directory --- galai/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/galai/__init__.py b/galai/__init__.py index ba535a6..863362d 100644 --- a/galai/__init__.py +++ b/galai/__init__.py @@ -17,6 +17,7 @@ def load_model( name: str, + cache_dir: str, dtype: Union[str, torch.dtype] = None, num_gpus: int = None, parallelize: bool = False @@ -128,6 +129,6 @@ def load_model( tensor_parallel=parallelize, ) model._set_tokenizer(hf_model) - model._load_checkpoint(checkpoint_path=hf_model) + model._load_checkpoint(checkpoint_path=hf_model, cache_dir=cache_dir) return model From 7fde8e93f6ccd4b776d9ef8da7fa7eda69389ad3 Mon Sep 17 00:00:00 2001 From: Bulat Date: Mon, 20 Feb 2023 00:30:11 +0300 Subject: [PATCH 2/2] Update model.py add variable cache_dir to download models in specified directory --- galai/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/galai/model.py b/galai/model.py index 866ffbb..fbf1610 100644 --- a/galai/model.py +++ b/galai/model.py @@ -82,7 +82,7 @@ def __init__( self.max_input_length = 2020 self._master_port = None - def _load_checkpoint(self, checkpoint_path: str): + def _load_checkpoint(self, checkpoint_path: str, cache_dir: str): """ Loads the checkpoint for the model @@ -108,6 +108,7 @@ def _load_checkpoint(self, checkpoint_path: str): self.model = OPTForCausalLM.from_pretrained( checkpoint_path, + cache_dir=cache_dir, torch_dtype=self.dtype, low_cpu_mem_usage=True, device_map=device_map,