diff --git a/.gitignore b/.gitignore
index 6ff9ed5..d28fdb3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
 *.backup
 *.history
 *.log
+*.json
+*.pyc
diff --git a/requirements.txt b/requirements.txt
index 5eda85b..74e99f5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ requests==2.31.0
 rich==13.7.0
 urllib3==2.2.1
 prompt-toolkit==3.0.43
+pyperclip==1.8.2
diff --git a/rich_chat/__init__.py b/rich_chat/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/rich_chat/chat_history.py b/rich_chat/chat_history.py
new file mode 100644
index 0000000..ff26a31
--- /dev/null
+++ b/rich_chat/chat_history.py
@@ -0,0 +1,114 @@
+import json
+import os
+import re
+from pathlib import Path
+from typing import Dict, List
+
+from prompt_toolkit import PromptSession
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+from prompt_toolkit.clipboard.pyperclip import PyperclipClipboard
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.key_binding import KeyBindings
+
+
+class ChatHistory:
+    def __init__(self, session_name: str, system_message: str = None):
+        # Define the cache path for storing chat history
+        home = os.environ.get("HOME", ".")  # get user's home path, else assume cwd
+        cache = Path(f"{home}/.cache/rich-chat")  # set the cache path
+        cache.mkdir(parents=True, exist_ok=True)  # ensure the directory exists
+
+        # Define the file path for storing chat history
+        self.file_path = cache / f"{session_name}.json"
+
+        # Define the file path for storing prompt session history
+        file_history_path = cache / f"{session_name}.history"
+        self.session = PromptSession(history=FileHistory(file_history_path))
+        self.auto_suggest = AutoSuggestFromHistory()
+
+        # Define the list for tracking chat messages.
+        # Each message is a dictionary with the following structure:
+        # {"role": "user/assistant/system", "content": ""}
+        self.messages: List[Dict[str, str]] = []
+        if system_message is not None:
+            self.messages.append({"role": "system", "content": system_message})
+
+    @property
+    def key_bindings(self) -> KeyBindings:
+        kb = KeyBindings()
+        clipboard = PyperclipClipboard()
+
+        for i in range(9):
+
+            @kb.add("c-s", "a", str(i))
+            def _(event):
+                """Copy the entire selected message to the system clipboard."""
+                if self.messages:
+                    # this doesn't auto-update. we need to re-render the toolbar somehow.
+                    self.bottom_toolbar = "Copied last message into clipboard!"
+                    # look at the last key pressed
+                    key = int(event.key_sequence[-1].key)
+                    # look up the message selected by that key, counting back from the end
+                    # note: the referenced key may not exist and can trigger an IndexError
+                    last_message_content = self.messages[-key]["content"].strip()
+                    clipboard.set_text(last_message_content)
+
+            @kb.add("c-s", "s", str(i))
+            def _(event):
+                """Copy only the code snippets from the selected message to the system clipboard."""
+                if self.messages:
+                    self.bottom_toolbar = (
+                        "Copied code blocks from last message into clipboard!"
+                    )
+                    key = int(event.key_sequence[-1].key)
+                    last_message_content = self.messages[-key]["content"].strip()
+                    code_snippets = re.findall(
+                        r"```(.*?)```", last_message_content, re.DOTALL
+                    )
+                    snippets_content = "\n\n".join(code_snippets)
+                    clipboard.set_text(snippets_content)
+
+        return kb
+
+    def load(self) -> List[Dict[str, str]]:
+        try:
+            with open(self.file_path, "r") as chat_session:
+                self.messages = json.load(chat_session)
+                return self.messages
+        except (FileNotFoundError, json.JSONDecodeError):
+            self.save()  # create the missing file
+            print(f"ChatHistoryLoad: Created new cache: {self.file_path}")
+
+    def save(self) -> None:
+        try:
+            with open(self.file_path, "w") as chat_session:
+                json.dump(self.messages, chat_session, indent=2)
+        except TypeError as e:
+            print(f"ChatHistoryWrite: {e}")
+
+    def prompt(self) -> str:
+        # Prompt the user for input
+        return self.session.prompt(
+            "Prompt: (⌥ + ⏎) | Copy: ((⌘ + s) (a|s) (.[0-9])) | Exit: (⌘ + c):\n",
+            key_bindings=self.key_bindings,
+            auto_suggest=self.auto_suggest,
+            multiline=True,
+        ).strip()
+
+    def append(self, message: Dict[str, str]) -> None:
+        self.messages.append(message)
+
+    def insert(self, index: int, element: object) -> None:
+        self.messages.insert(index, element)
+
+    def pop(self, index: int) -> Dict[str, str]:
+        return self.messages.pop(index)
+
+    def replace(self, index: int, content: str) -> None:
+        try:
+            self.messages[index]["content"] = content
+        except (IndexError, KeyError) as e:
+            print(f"ChatHistoryReplace: Failed to substitute chat message: {e}")
+
+    def reset(self) -> None:
+        self.messages = []
diff --git a/source/rich-chat.py b/rich_chat/rich_chat_cli.py
similarity index 73%
rename from source/rich-chat.py
rename to rich_chat/rich_chat_cli.py
index 9bb8b8b..0883bd0 100644
--- a/source/rich-chat.py
+++ b/rich_chat/rich_chat_cli.py
@@ -1,15 +1,18 @@
+#!/usr/bin/env python
+
 import argparse
 import json
 import os
 
 import requests
-from prompt_toolkit import PromptSession
-from prompt_toolkit.history import FileHistory
+from prompt_toolkit import prompt as input
 from rich.console import Console
 from rich.live import Live
 from rich.markdown import Markdown
 from rich.panel import Panel
 
+from rich_chat.chat_history import ChatHistory
+
 
 def remove_lines_console(num_lines):
     for _ in range(num_lines):
@@ -28,14 +31,13 @@ def estimate_lines(text):
     return line_count
 
 
-def handle_console_input(session: PromptSession) -> str:
-    return session.prompt("(Prompt: ⌥ + ⏎) | (Exit: ⌘ + c):\n", multiline=True).strip()
-
-
 class conchat:
     def __init__(
         self,
         server_addr,
+        min_p: float,
+        repeat_penalty: float,
+        seed: int,
         top_k=10,
         top_p=0.95,
         temperature=0.12,
@@ -43,36 +45,54 @@ def __init__(
         stream: bool = True,
         cache_prompt: bool = True,
         model_frame_color: str = "red",
+        chat_history: ChatHistory = None,
     ) -> None:
         self.model_frame_color = model_frame_color
         self.serveraddr = server_addr
         self.topk = top_k
         self.top_p = top_p
+        self.seed = seed
+        self.min_p = min_p
+        self.repeat_penalty = repeat_penalty
         self.temperature = temperature
         self.n_predict = n_predict
         self.stream = stream
         self.cache_prompt = cache_prompt
         self.headers = {"Content-Type": "application/json"}
-        self.chat_history = []
+        self.chat_history = chat_history
         self.model_name = ""
         self.console = Console()
-        # TODO: Gracefully handle user input history file.
-        self.session = PromptSession(history=FileHistory(".rich-chat.history"))
+        self._render_messages_once_on_start()
+
+    def _render_messages_once_on_start(self) -> None:
+        self.chat_history.load()
+        for message in self.chat_history.messages:
+            title = message["role"] if message["role"] != "user" else "HUMAN"
+            self.console.print(
+                Panel(
+                    Markdown(message["content"]),
+                    title=title.upper(),
+                    title_align="left",
+                )
+            )
 
     def chat_generator(self, prompt):
         endpoint = self.serveraddr + "/v1/chat/completions"
         self.chat_history.append({"role": "user", "content": prompt})
         payload = {
-            "messages": self.chat_history,
+            "messages": self.chat_history.messages,
             "temperature": self.temperature,
             "top_k": self.topk,
             "top_p": self.top_p,
             "n_predict": self.n_predict,
             "stream": self.stream,
             "cache_prompt": self.cache_prompt,
+            "seed": self.seed,
+            "repeat_penalty": self.repeat_penalty,
+            "min_p": self.min_p,
         }
         try:
             response = requests.post(
@@ -150,7 +170,7 @@ def chat(self):
         self.model_name = self.get_model_name()
         while True:
             try:
-                user_m = handle_console_input(self.session)
+                user_m = self.chat_history.prompt()
                 remove_lines_console(estimate_lines(text=user_m))
                 self.console.print(
                     Panel(Markdown(user_m), title="HUMAN", title_align="left")
@@ -160,6 +180,7 @@ def chat(self):
             # NOTE: Ctrl + c (keyboard) or Ctrl + d (eof) to exit
             # Adding EOFError prevents an exception and gracefully exits.
             except (KeyboardInterrupt, EOFError):
+                self.chat_history.save()
                 exit()
@@ -194,17 +215,54 @@ def main():
         type=int,
         help="The number defines how many tokens to be predict by the model. Default: infinity until [stop] token.",
     )
+    parser.add_argument(
+        "--minp",
+        type=float,
+        default=0.05,
+        help="The minimum probability for a token to be considered, relative to the probability of the most likely token (default: 0.05).",
+    )
+    parser.add_argument(
+        "--repeat-penalty",
+        type=float,
+        default=1.1,
+        help="Control the repetition of token sequences in the generated text (default: 1.1).",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=-1,
+        help="Set the random number generator (RNG) seed (default: -1; -1 means a random seed).",
+    )
+    parser.add_argument(
+        "-m",
+        "--system-message",
+        type=str,
+        default=None,  # empty by default; avoiding assumptions.
+        help="The system message used to orient the model, if any.",
+    )
+    parser.add_argument(
+        "-n",
+        "--session-name",
+        type=str,
+        default="rich-chat",
+        help="The name of the chat session. Default is 'rich-chat'.",
+    )
 
     args = parser.parse_args()
-    # print(args)
-    # print(f"ARG of server is {args.server}")
-    # print(f"argument of bot color is {args.model_frame_color}")
+
+    # Create the chat history for this session; cached under ~/.cache/rich-chat/<session-name>.json
+    chat_history = ChatHistory(args.session_name, args.system_message)
+
     chat = conchat(
         server_addr=args.server,
         top_k=args.topk,
         top_p=args.topp,
         temperature=args.temperature,
         model_frame_color=args.model_frame_color,
+        min_p=args.minp,
+        seed=args.seed,
+        repeat_penalty=args.repeat_penalty,
+        chat_history=chat_history,
     )
     chat.chat()
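
For reviewers who want to try the new history module outside the interactive CLI, here is a minimal sketch of the `ChatHistory` lifecycle introduced by this patch. The session name and message contents are placeholders for illustration only; the cache location follows the `~/.cache/rich-chat/<session_name>.json` convention set up in `ChatHistory.__init__`.

```python
from rich_chat.chat_history import ChatHistory

# Hypothetical session values, for illustration only.
history = ChatHistory("demo-session", system_message="You are a helpful assistant.")

history.load()  # restores ~/.cache/rich-chat/demo-session.json, or creates it if missing
history.append({"role": "user", "content": "Hello!"})
history.append({"role": "assistant", "content": "Hi! How can I help?"})
history.save()  # writes the messages list back to the JSON cache

# history.prompt() starts the interactive multiline prompt with the c-s copy
# key bindings; the CLI wires this in through conchat.chat().
```

The CLI itself derives these values from the new flags added in `main()`, e.g. `--session-name`, `--system-message`, `--minp`, `--repeat-penalty`, and `--seed`.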