
Commit 02172ea

Add chatdocs.yml configuration file
Replace all command-line options with the new `chatdocs.yml` config file.
1 parent 9726b9f · commit 02172ea

15 files changed: +229 −177 lines

.github/workflows/tests.yml (+7, -1)
@@ -35,8 +35,14 @@ jobs:
           python -m pip install --upgrade pip
           pip install .
 
+      - name: Copy chatdocs.yml
+        run: cp tests/fixtures/chatdocs.yml .
+
+      - name: Test download
+        run: chatdocs download
+
       - name: Test add
         run: chatdocs add examples
 
       - name: Test chat
-        run: chatdocs chat 'Why was the NATO created?' --model marella/gpt-2-ggml --download --lib avx
+        run: chatdocs chat 'Why was the NATO created?'
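
Note: the `tests/fixtures/chatdocs.yml` fixture copied above is among the 15 changed files but is not shown on this page. A hypothetical sketch, assuming it simply moves the options removed from the old `chatdocs chat` command into the new config format:

# tests/fixtures/chatdocs.yml -- hypothetical reconstruction, not shown in
# this excerpt. Mirrors the old flags --model marella/gpt-2-ggml --download --lib avx.
ctransformers:
  model: marella/gpt-2-ggml
  lib: avx
download: true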

.gitignore (+1)
@@ -1,3 +1,4 @@
+/chatdocs.yml
 /db/
 
 # Created by https://www.toptal.com/developers/gitignore/api/c++,python,cmake,linux,macos,windows,sublimetext,vim,visualstudio,visualstudiocode

chatdocs/add.py (+6, -25)
@@ -1,6 +1,6 @@
 import os
 import glob
-from typing import List
+from typing import Any, Dict, List
 from multiprocessing import Pool
 
 from tqdm import tqdm
@@ -18,12 +18,9 @@
     UnstructuredWordDocumentLoader,
 )
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.vectorstores import Chroma
-from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.docstore.document import Document
-from chromadb.config import Settings
 
-from . import config
+from .vectorstores import get_vectorstore, get_vectorstore_from_documents
 
 
 # Custom document loaders
@@ -143,23 +140,12 @@ def does_vectorstore_exist(persist_directory: str) -> bool:
     return False
 
 
-def add(source_directory: str, persist_directory: str) -> None:
-    # Create embeddings
-    embeddings = HuggingFaceInstructEmbeddings(model_name=config.EMBEDDINGS_MODEL)
-    chroma_settings = Settings(
-        chroma_db_impl=config.CHROMA_DB_IMPL,
-        persist_directory=persist_directory,
-        anonymized_telemetry=False,
-    )
-
+def add(config: Dict[str, Any], source_directory: str) -> None:
+    persist_directory = config["chroma"]["persist_directory"]
     if does_vectorstore_exist(persist_directory):
         # Update and store locally vectorstore
         print(f"Appending to existing vectorstore at {persist_directory}")
-        db = Chroma(
-            persist_directory=persist_directory,
-            embedding_function=embeddings,
-            client_settings=chroma_settings,
-        )
+        db = get_vectorstore(config)
         collection = db.get()
         texts = process_documents(
             source_directory,
@@ -172,11 +158,6 @@ def add(source_directory: str, persist_directory: str) -> None:
         print("Creating new vectorstore")
         texts = process_documents(source_directory)
         print(f"Creating embeddings. May take a few minutes...")
-        db = Chroma.from_documents(
-            texts,
-            embeddings,
-            persist_directory=persist_directory,
-            client_settings=chroma_settings,
-        )
+        db = get_vectorstore_from_documents(config, texts)
     db.persist()
     db = None

chatdocs/chains.py (+17)
@@ -0,0 +1,17 @@
+from typing import Any, Dict
+
+from langchain.chains import RetrievalQA
+
+from .llms import get_llm
+from .vectorstores import get_vectorstore
+
+
+def get_retrieval_qa(config: Dict[str, Any]) -> RetrievalQA:
+    db = get_vectorstore(config)
+    retriever = db.as_retriever(**config["retriever"])
+    llm = get_llm(config)
+    return RetrievalQA.from_chain_type(
+        llm=llm,
+        retriever=retriever,
+        return_source_documents=True,
+    )
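
Since `config["retriever"]` is expanded directly into `as_retriever`, the `retriever` section of the YAML maps one-to-one onto retriever keyword arguments. With the defaults added in `chatdocs/data/chatdocs.yml` below, the call is equivalent to the hard-coded retriever removed from `chat.py`:

# Equivalent call under the default config (retriever.search_kwargs.k: 4):
retriever = db.as_retriever(search_kwargs={"k": 4})

The explicit `chain_type="stuff"` argument from the old `chat.py` was dropped; it is the default for `RetrievalQA.from_chain_type`.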

chatdocs/chat.py (+7, -67)
@@ -1,75 +1,15 @@
-from typing import Any, Optional
+from typing import Any, Dict, Optional
 
-from chromadb.config import Settings
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.chains import RetrievalQA
-from langchain.embeddings import HuggingFaceInstructEmbeddings
-from langchain.llms import CTransformers, HuggingFacePipeline
-from langchain.vectorstores import Chroma
 from rich import print
 from rich.markup import escape
 from rich.panel import Panel
 
-from . import config
+from .chains import get_retrieval_qa
+from .utils import print_answer
 
 
-def print_response(text: str) -> None:
-    print(f"[bright_cyan]{escape(text)}", end="", flush=True)
-
-
-class StreamingPrintCallbackHandler(StreamingStdOutCallbackHandler):
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        print_response(token)
-
-
-def chat(
-    *,
-    persist_directory: str,
-    hf: bool,
-    download: bool,
-    model: str,
-    model_type: Optional[str] = None,
-    model_file: Optional[str] = None,
-    lib: Optional[str] = None,
-    query: Optional[str] = None,
-) -> None:
-    local_files_only = not download
-    embeddings = HuggingFaceInstructEmbeddings(model_name=config.EMBEDDINGS_MODEL)
-    chroma_settings = Settings(
-        chroma_db_impl=config.CHROMA_DB_IMPL,
-        persist_directory=persist_directory,
-        anonymized_telemetry=False,
-    )
-    db = Chroma(
-        persist_directory=persist_directory,
-        embedding_function=embeddings,
-        client_settings=chroma_settings,
-    )
-    retriever = db.as_retriever(search_kwargs={"k": 4})
-
-    if hf:
-        llm = HuggingFacePipeline.from_model_id(
-            model_id=model,
-            task="text-generation",
-            model_kwargs={"local_files_only": local_files_only},
-            pipeline_kwargs={"max_new_tokens": 256},
-        )
-    else:
-        llm = CTransformers(
-            model=model,
-            model_type=model_type,
-            model_file=model_file,
-            config={"context_length": 1024, "local_files_only": local_files_only},
-            lib=lib,
-            callbacks=[StreamingPrintCallbackHandler()],
-        )
-
-    qa = RetrievalQA.from_chain_type(
-        llm=llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True,
-    )
+def chat(config: Dict[str, Any], query: Optional[str] = None) -> None:
+    qa = get_retrieval_qa(config)
 
     interactive = not query
     print()
@@ -89,8 +29,8 @@ def chat(
         print("[bold]A:", end="", flush=True)
 
         res = qa(query)
-        if hf:
-            print_response(res["result"])
+        if config["llm"] != "ctransformers":
+            print_answer(res["result"])
 
         print()
         for doc in res["source_documents"]:

chatdocs/config.py (+25, -5)
@@ -1,6 +1,26 @@
-EMBEDDINGS_MODEL = "hkunlp/instructor-large"
-MODEL = "TheBloke/Wizard-Vicuna-7B-Uncensored-GGML"
-MODEL_TYPE = "llama"
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
 
-CHROMA_DB_IMPL = "duckdb+parquet"
-PERSIST_DIRECTORY = "db"
+import yaml
+
+from .utils import merge
+
+FILENAME = "chatdocs.yml"
+
+
+def _get_config(path: Union[Path, str]) -> Dict[str, Any]:
+    path = Path(path)
+    if path.is_dir():
+        path = path / FILENAME
+    with open(path) as f:
+        return yaml.safe_load(f)
+
+
+def get_config(path: Optional[Union[Path, str]] = None) -> Dict[str, Any]:
+    default_config = _get_config(Path(__file__).parent / "data")
+    if path is None:
+        path = Path() / FILENAME
+        if not path.is_file():
+            return default_config
+    config = _get_config(path)
+    return merge(default_config, config)
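
`get_config` layers the user's `chatdocs.yml` over the packaged defaults via `merge` from `chatdocs/utils.py`, which is not shown on this page. A minimal sketch of the assumed behavior, i.e. a recursive dict merge where user values win:

# Hypothetical sketch of chatdocs.utils.merge (the real helper lives in
# chatdocs/utils.py, outside this excerpt): nested dicts are merged
# recursively and values from `b` override values from `a`.
from typing import Any, Dict


def merge(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
    result = {**a}
    for key, value in b.items():
        if isinstance(result.get(key), dict) and isinstance(value, dict):
            result[key] = merge(result[key], value)  # merge nested sections
        else:
            result[key] = value  # new key or scalar: `b` wins
    return result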

chatdocs/data/chatdocs.yml (+26)
@@ -0,0 +1,26 @@
+embeddings:
+  model: hkunlp/instructor-large
+
+llm: ctransformers
+
+ctransformers:
+  model: TheBloke/Wizard-Vicuna-7B-Uncensored-GGML
+  model_type: llama
+  config:
+    context_length: 1024
+
+huggingface:
+  model: TheBloke/Wizard-Vicuna-7B-Uncensored-HF
+  pipeline_kwargs:
+    max_new_tokens: 256
+
+download: false
+
+chroma:
+  persist_directory: db
+  chroma_db_impl: duckdb+parquet
+  anonymized_telemetry: false
+
+retriever:
+  search_kwargs:
+    k: 4
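
Because these packaged defaults are deep-merged with the user's file, a `chatdocs.yml` in the working directory only needs the keys it overrides. For example (hypothetical model choice), this is a complete user config:

# chatdocs.yml -- every key not listed here keeps its packaged default.
ctransformers:
  model: marella/gpt-2-ggml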

chatdocs/download.py (+7, -6)
@@ -1,9 +1,10 @@
-from langchain.embeddings import HuggingFaceInstructEmbeddings
-from langchain.llms import CTransformers
+from typing import Any, Dict
 
-from . import config
+from .embeddings import get_embeddings
+from .llms import get_llm
 
 
-def download() -> None:
-    HuggingFaceInstructEmbeddings(model_name=config.EMBEDDINGS_MODEL)
-    CTransformers(model=config.MODEL, model_type=config.MODEL_TYPE)
+def download(config: Dict[str, Any]) -> None:
+    config = {**config, "download": True}
+    get_embeddings(config)
+    get_llm(config)

chatdocs/embeddings.py (+14)
@@ -0,0 +1,14 @@
+from typing import Any, Dict
+
+from langchain.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
+from langchain.embeddings.base import Embeddings
+
+
+def get_embeddings(config: Dict[str, Any]) -> Embeddings:
+    config = {**config["embeddings"]}
+    config["model_name"] = config.pop("model")
+    if config["model_name"].startswith("hkunlp/"):
+        Provider = HuggingFaceInstructEmbeddings
+    else:
+        Provider = HuggingFaceEmbeddings
+    return Provider(**config)
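
The `embeddings` section of the config is passed through as constructor kwargs after renaming `model` to `model_name`, so with the packaged defaults this resolves to:

# Resolved call under the default config:
embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large")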

chatdocs/llms.py (+26)
@@ -0,0 +1,26 @@
+from typing import Any, Dict
+
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.llms import CTransformers, HuggingFacePipeline
+from langchain.llms.base import LLM
+
+from .utils import merge, print_answer
+
+
+class StreamingPrintCallbackHandler(StreamingStdOutCallbackHandler):
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        print_answer(token)
+
+
+def get_llm(config: Dict[str, Any]) -> LLM:
+    local_files_only = not config["download"]
+    if config["llm"] == "ctransformers":
+        config = {**config["ctransformers"]}
+        config = merge(config, {"config": {"local_files_only": local_files_only}})
+        llm = CTransformers(callbacks=[StreamingPrintCallbackHandler()], **config)
+    else:
+        config = {**config["huggingface"]}
+        config["model_id"] = config.pop("model")
+        config = merge(config, {"model_kwargs": {"local_files_only": local_files_only}})
+        llm = HuggingFacePipeline.from_model_id(task="text-generation", **config)
+    return llm
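
The `chatdocs/vectorstores.py` module that `add.py`, `chains.py`, and `chat.py` now import is among the 15 changed files but is not shown on this page. A plausible sketch, reconstructed from the Chroma code removed from `add.py` and `chat.py` and the `chroma` section of the default config:

# Hypothetical reconstruction of chatdocs/vectorstores.py; the real file is
# part of this commit but not included in this excerpt.
from typing import Any, Dict, List

from chromadb.config import Settings
from langchain.docstore.document import Document
from langchain.vectorstores import Chroma

from .embeddings import get_embeddings


def _get_settings(config: Dict[str, Any]) -> Settings:
    # The `chroma` section maps directly onto chromadb client settings:
    # persist_directory, chroma_db_impl, anonymized_telemetry.
    return Settings(**config["chroma"])


def get_vectorstore(config: Dict[str, Any]) -> Chroma:
    # Open an existing vectorstore, as the removed Chroma(...) call did.
    return Chroma(
        persist_directory=config["chroma"]["persist_directory"],
        embedding_function=get_embeddings(config),
        client_settings=_get_settings(config),
    )


def get_vectorstore_from_documents(
    config: Dict[str, Any], documents: List[Document]
) -> Chroma:
    # Build a new vectorstore, as the removed Chroma.from_documents(...) did.
    return Chroma.from_documents(
        documents,
        get_embeddings(config),
        persist_directory=config["chroma"]["persist_directory"],
        client_settings=_get_settings(config),
    )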
