@@ -1,75 +1,15 @@
-from typing import Any, Optional
+from typing import Any, Dict, Optional
 
-from chromadb.config import Settings
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.chains import RetrievalQA
-from langchain.embeddings import HuggingFaceInstructEmbeddings
-from langchain.llms import CTransformers, HuggingFacePipeline
-from langchain.vectorstores import Chroma
 from rich import print
 from rich.markup import escape
 from rich.panel import Panel
 
-from . import config
+from .chains import get_retrieval_qa
+from .utils import print_answer
 
 
-def print_response(text: str) -> None:
-    print(f"[bright_cyan]{escape(text)}", end="", flush=True)
-
-
-class StreamingPrintCallbackHandler(StreamingStdOutCallbackHandler):
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        print_response(token)
-
-
-def chat(
-    *,
-    persist_directory: str,
-    hf: bool,
-    download: bool,
-    model: str,
-    model_type: Optional[str] = None,
-    model_file: Optional[str] = None,
-    lib: Optional[str] = None,
-    query: Optional[str] = None,
-) -> None:
-    local_files_only = not download
-    embeddings = HuggingFaceInstructEmbeddings(model_name=config.EMBEDDINGS_MODEL)
-    chroma_settings = Settings(
-        chroma_db_impl=config.CHROMA_DB_IMPL,
-        persist_directory=persist_directory,
-        anonymized_telemetry=False,
-    )
-    db = Chroma(
-        persist_directory=persist_directory,
-        embedding_function=embeddings,
-        client_settings=chroma_settings,
-    )
-    retriever = db.as_retriever(search_kwargs={"k": 4})
-
-    if hf:
-        llm = HuggingFacePipeline.from_model_id(
-            model_id=model,
-            task="text-generation",
-            model_kwargs={"local_files_only": local_files_only},
-            pipeline_kwargs={"max_new_tokens": 256},
-        )
-    else:
-        llm = CTransformers(
-            model=model,
-            model_type=model_type,
-            model_file=model_file,
-            config={"context_length": 1024, "local_files_only": local_files_only},
-            lib=lib,
-            callbacks=[StreamingPrintCallbackHandler()],
-        )
-
-    qa = RetrievalQA.from_chain_type(
-        llm=llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True,
-    )
+def chat(config: Dict[str, Any], query: Optional[str] = None) -> None:
+    qa = get_retrieval_qa(config)
 
     interactive = not query
     print()
@@ -89,8 +29,8 @@ def chat(
         print("[bold]A:", end="", flush=True)
 
         res = qa(query)
-        if hf:
-            print_response(res["result"])
+        if config["llm"] != "ctransformers":
+            print_answer(res["result"])
 
         print()
         for doc in res["source_documents"]:
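The retriever and model setup deleted above now lives behind `get_retrieval_qa`. The new `chains` module itself is not part of this diff, but a minimal sketch of what it plausibly contains, reassembled from the deleted code, could look like the following. Every config key except "llm" (which appears in the diff) is a hypothetical name:

# chains.py -- hypothetical reconstruction; only the get_retrieval_qa name
# and the "llm" config key are confirmed by this diff.
from typing import Any, Dict

from chromadb.config import Settings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import CTransformers, HuggingFacePipeline
from langchain.vectorstores import Chroma


def get_retrieval_qa(config: Dict[str, Any]) -> RetrievalQA:
    # Same embedding/vector-store wiring as the deleted chat() body,
    # driven by the config dict instead of keyword arguments.
    embeddings = HuggingFaceInstructEmbeddings(model_name=config["embeddings"])
    db = Chroma(
        persist_directory=config["persist_directory"],
        embedding_function=embeddings,
        client_settings=Settings(
            persist_directory=config["persist_directory"],
            anonymized_telemetry=False,
        ),
    )
    if config["llm"] == "ctransformers":
        # The custom streaming handler was deleted from chat.py, so the
        # token-by-token printing presumably happens via callbacks here now.
        llm = CTransformers(
            callbacks=[StreamingStdOutCallbackHandler()],
            **config["ctransformers"],
        )
    else:
        llm = HuggingFacePipeline.from_model_id(
            model_id=config["huggingface"]["model"],
            task="text-generation",
            pipeline_kwargs={"max_new_tokens": 256},
        )
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 4}),
        return_source_documents=True,
    )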
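Similarly, `print_answer` replaces the deleted `print_response`. A sketch of the assumed `utils` helper, carrying the old behavior over unchanged:

# utils.py -- assumed to carry over the deleted print_response() behavior.
from rich import print
from rich.markup import escape


def print_answer(text: str) -> None:
    print(f"[bright_cyan]{escape(text)}", end="", flush=True)

The `!= "ctransformers"` guard in chat() makes sense under that reading: with a ctransformers model the answer has presumably already been streamed token by token through the LLM callbacks, so printing res["result"] again would duplicate it, while other backends print the full result once.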