24 changes: 24 additions & 0 deletions .gitignore
@@ -0,0 +1,24 @@
# Python
__pycache__/
*.py[cod]
*$py.class

# Virtual environments
venv/
env/
.env/

# Distribution artifacts
dist/
build/
*.egg-info/

# IDE files
.vscode/
.idea/
*.swp
*.swo

# System files
.DS_Store
Thumbs.db
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -38,6 +38,10 @@ dependencies = [
litellm = [
"litellm == 1.42.5"
]
gemini = [
"google-generativeai >= 0.3.2",
"shell_gpt[litellm]"
]
test = [
"pytest >= 7.2.2, < 8.0.0",
"requests-mock[fixture] >= 1.10.0, < 2.0.0",
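Note (not part of the diff): the new optional dependency group should be installable with pip's extras syntax, e.g. pip install "shell_gpt[gemini]", which pulls in google-generativeai and, via the self-referencing shell_gpt[litellm] extra, litellm.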
48 changes: 47 additions & 1 deletion sgpt/app.py
@@ -19,11 +19,42 @@
from sgpt.utils import (
get_edited_prompt,
get_sgpt_version,
install_shell_completion,
install_shell_integration,
list_models,
run_command,
)

app = typer.Typer(
add_completion=False,
)


# Add helper to display resolved default model in --help output.

def _resolved_default_model() -> str:
"""Return the effective default model taking provider into account.

If DEFAULT_MODEL is set to the placeholder 'auto_select_default_model',
we determine the concrete default depending on the selected provider
(OpenAI or Gemini). This value is only used for documentation purposes
so the runtime behaviour remains unchanged (the actual resolution still
happens deeper in the handlers).
"""
default_model = cfg.get('DEFAULT_MODEL')
if default_model != 'auto_select_default_model':
return default_model

provider = cfg.get('LLM_API_PROVIDER')
if provider == 'gemini':
return cfg.get('DEFAULT_MODEL_GEMINI')
if provider == 'openai':
return cfg.get('DEFAULT_MODEL_OPENAI')
# Fallback to the raw value if provider is unexpected
return default_model


@app.command()
def main(
prompt: str = typer.Argument(
"",
@@ -33,6 +64,7 @@ def main(
model: str = typer.Option(
cfg.get("DEFAULT_MODEL"),
help="Large language model to use.",
show_default=_resolved_default_model(),
),
temperature: float = typer.Option(
0.0,
@@ -95,6 +127,13 @@ def main(
help="Show version.",
callback=get_sgpt_version,
),
list_models: bool = typer.Option(
False,
"--list-models",
"-lm",
help="List all available models.",
callback=list_models,
),
chat: str = typer.Option(
None,
help="Follow conversation with id, " 'use "temp" for quick session.',
@@ -149,6 +188,13 @@ def main(
callback=install_shell_integration,
hidden=True, # Hiding since should be used only once.
),
# --install-completion is overridden because the default implementation is too slow.
install_completion: bool = typer.Option(
False,
help="Install shell completions (ZSH only)",
callback=install_shell_completion,
hidden=True, # Hiding since should be used only once.
),
install_functions: bool = typer.Option(
False,
help="Install default functions.",
@@ -268,7 +314,7 @@ def main(


def entry_point() -> None:
typer.run(main)
app()


if __name__ == "__main__":
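Note (not part of the diff): with LLM_API_PROVIDER set to gemini and DEFAULT_MODEL left at the auto_select_default_model placeholder, sgpt --help would presumably show gemini/gemini-2.0-flash as the model default (or gpt-4o for openai); the model actually used at runtime is still resolved in the handler, see gemini_handler.py below.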
11 changes: 11 additions & 0 deletions sgpt/completions/_sgpt
@@ -0,0 +1,11 @@
#compdef sgpt

_sgpt() {
_arguments \
'(-s --shell)'{-s,--shell}'[Generate and execute shell commands]' \
'(-c --code)'{-c,--code}'[Generate only code]' \
'--repl[Start a REPL session]' \
'*:Enter your prompt for GPT:_default'
}

compdef _sgpt sgpt
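Note (not part of the diff): this ZSH compdef file presumably backs the new hidden --install-completion flag added in app.py above, whose help text advertises ZSH-only completions.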
25 changes: 21 additions & 4 deletions sgpt/config.py
@@ -22,7 +22,9 @@
"CHAT_CACHE_LENGTH": int(os.getenv("CHAT_CACHE_LENGTH", "100")),
"CACHE_LENGTH": int(os.getenv("CHAT_CACHE_LENGTH", "100")),
"REQUEST_TIMEOUT": int(os.getenv("REQUEST_TIMEOUT", "60")),
"DEFAULT_MODEL": os.getenv("DEFAULT_MODEL", "gpt-4o"),
"DEFAULT_MODEL": os.getenv("DEFAULT_MODEL", "auto_select_default_model"),
"DEFAULT_MODEL_GEMINI": os.getenv("DEFAULT_MODEL_GEMINI", "gemini/gemini-2.0-flash"),
"DEFAULT_MODEL_OPENAI": os.getenv("DEFAULT_MODEL_OPENAI", "gpt-4o"),
"DEFAULT_COLOR": os.getenv("DEFAULT_COLOR", "magenta"),
"ROLE_STORAGE_PATH": os.getenv("ROLE_STORAGE_PATH", str(ROLE_STORAGE_PATH)),
"DEFAULT_EXECUTE_SHELL_CMD": os.getenv("DEFAULT_EXECUTE_SHELL_CMD", "false"),
@@ -32,6 +34,7 @@
"OPENAI_USE_FUNCTIONS": os.getenv("OPENAI_USE_FUNCTIONS", "true"),
"SHOW_FUNCTIONS_OUTPUT": os.getenv("SHOW_FUNCTIONS_OUTPUT", "false"),
"API_BASE_URL": os.getenv("API_BASE_URL", "default"),
"GEMINI_API_BASE_URL": os.getenv("GEMINI_API_BASE_URL", "default"),
"PRETTIFY_MARKDOWN": os.getenv("PRETTIFY_MARKDOWN", "true"),
"USE_LITELLM": os.getenv("USE_LITELLM", "false"),
"SHELL_INTERACTION": os.getenv("SHELL_INTERACTION ", "true"),
@@ -57,9 +60,23 @@ def __init__(self, config_path: Path, **defaults: Any):
else:
config_path.parent.mkdir(parents=True, exist_ok=True)
# Don't write API key to config file if it is in the environment.
if not defaults.get("OPENAI_API_KEY") and not os.getenv("OPENAI_API_KEY"):
__api_key = getpass(prompt="Please enter your OpenAI API key: ")
defaults["OPENAI_API_KEY"] = __api_key
__llm_provider = input("Which LLM provider do you want to use? [openai/gemini]: ").lower()
while __llm_provider not in ['openai', 'gemini']:
__llm_provider = input("Please enter 'openai' or 'gemini': ").lower()

defaults['LLM_API_PROVIDER'] = __llm_provider

if __llm_provider == 'openai':
if not defaults.get("OPENAI_API_KEY") and not os.getenv("OPENAI_API_KEY"):
__api_key = getpass(prompt="Please enter your OpenAI API key: ")
defaults["OPENAI_API_KEY"] = __api_key
else:
if not defaults.get("GEMINI_API_KEY") and not os.getenv("GEMINI_API_KEY"):
__api_key = getpass(prompt="Please enter your Gemini API key: ")
defaults["GEMINI_API_KEY"] = __api_key
super().__init__(**defaults)
self._write()

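Note (not part of the diff): the provider prompt only runs when no config file exists yet, and the key prompt is skipped if GEMINI_API_KEY (or OPENAI_API_KEY) is already present in the defaults or the environment, mirroring the existing OpenAI behaviour.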
2 changes: 1 addition & 1 deletion sgpt/handlers/chat_handler.py
@@ -10,7 +10,7 @@
from ..config import cfg
from ..role import DefaultRoles, SystemRole
from ..utils import option_callback
from .handler import Handler
from .gemini_handler import Handler

CHAT_CACHE_LENGTH = int(cfg.get("CHAT_CACHE_LENGTH"))
CHAT_CACHE_PATH = Path(cfg.get("CHAT_CACHE_PATH"))
2 changes: 1 addition & 1 deletion sgpt/handlers/default_handler.py
@@ -3,7 +3,7 @@

from ..config import cfg
from ..role import SystemRole
from .handler import Handler
from .gemini_handler import Handler

CHAT_CACHE_LENGTH = int(cfg.get("CHAT_CACHE_LENGTH"))
CHAT_CACHE_PATH = Path(cfg.get("CHAT_CACHE_PATH"))
183 changes: 183 additions & 0 deletions sgpt/handlers/gemini_handler.py
@@ -0,0 +1,183 @@
import json
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional

import litellm # type: ignore

from ..cache import Cache
from ..config import cfg
from ..function import get_function
from ..printer import MarkdownPrinter, Printer, TextPrinter
from ..role import DefaultRoles, SystemRole

completion: Callable[..., Any] = lambda *args, **kwargs: Generator[Any, None, None]

provider = cfg.get("LLM_API_PROVIDER")
if provider == "gemini":
base_url = cfg.get("GEMINI_API_BASE_URL")
api_key = cfg.get("GEMINI_API_KEY")
elif provider == "openai":
base_url = cfg.get("API_BASE_URL")
api_key = cfg.get("OPENAI_API_KEY")
else:
raise ValueError(f"Invalid provider: {provider}")

additional_kwargs = {
"timeout": int(cfg.get("REQUEST_TIMEOUT")),
"api_key": api_key,
"base_url": None if base_url == "default" else base_url,
# "base_url": "https://generativelanguage.googleapis.com/v1beta/models/",
}

completion = litellm.completion
litellm.suppress_debug_info = True


def validate_model_with_provider(model: str, provider: str) -> bool:
if provider == 'gemini':
return 'gemini' in model.lower()
if provider == 'openai':
return 'gpt' in model.lower()
return False


class Handler:
cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH")))

def __init__(self, role: SystemRole, markdown: bool) -> None:
self.role = role

api_base_url = base_url
self.base_url = None if api_base_url == "default" else api_base_url
self.timeout = int(cfg.get("REQUEST_TIMEOUT"))

self.markdown = "APPLY MARKDOWN" in self.role.role and markdown
self.code_theme, self.color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR")

@property
def printer(self) -> Printer:
return (
MarkdownPrinter(self.code_theme)
if self.markdown
else TextPrinter(self.color)
)

def make_messages(self, prompt: str) -> List[Dict[str, str]]:
raise NotImplementedError

def handle_function_call(
self,
messages: List[Dict[str, Any]],
name: str,
arguments: str,
) -> Generator[str, None, None]:
messages.append(
{
"role": "assistant",
"content": "",
"function_call": {"name": name, "arguments": arguments},
}
)

if messages and messages[-1]["role"] == "assistant":
yield "\n"

dict_args = json.loads(arguments)
joined_args = ", ".join(f'{k}="{v}"' for k, v in dict_args.items())
yield f"> @FunctionCall `{name}({joined_args})` \n\n"

result = get_function(name)(**dict_args)
if cfg.get("SHOW_FUNCTIONS_OUTPUT") == "true":
yield f"```text\n{result}\n```\n"
messages.append({"role": "function", "content": result, "name": name})

@cache
def get_completion(
self,
model: str,
temperature: float,
top_p: float,
messages: List[Dict[str, Any]],
functions: Optional[List[Dict[str, str]]],
) -> Generator[str, None, None]:
name = arguments = ""
is_shell_role = self.role.name == DefaultRoles.SHELL.value
is_code_role = self.role.name == DefaultRoles.CODE.value
is_dsc_shell_role = self.role.name == DefaultRoles.DESCRIBE_SHELL.value
if is_shell_role or is_code_role or is_dsc_shell_role:
functions = None

if functions:
additional_kwargs["tool_choice"] = "auto"
additional_kwargs["tools"] = functions
additional_kwargs["parallel_tool_calls"] = False

if model == 'auto_select_default_model':
match provider:
case 'gemini':
model = cfg.get('DEFAULT_MODEL_GEMINI')
case 'openai':
model = cfg.get('DEFAULT_MODEL_OPENAI')
case _:
raise ValueError(f'Invalid provider: {provider}')

assert validate_model_with_provider(model, provider), \
f'Model {model} is not compatible with provider {provider}.'

response = completion(
model=model,
messages=messages,
stream=True,
**additional_kwargs,
)

try:
for chunk in response:
delta = chunk.choices[0].delta

# LiteLLM uses dict instead of Pydantic object like OpenAI does.
tool_calls = delta.get("tool_calls")
if tool_calls:
for tool_call in tool_calls:
if tool_call.function.name:
name = tool_call.function.name
if tool_call.function.arguments:
arguments += tool_call.function.arguments
if chunk.choices[0].finish_reason == "tool_calls":
yield from self.handle_function_call(messages, name, arguments)
yield from self.get_completion(
model=model,
temperature=temperature,
top_p=top_p,
messages=messages,
functions=functions,
caching=False,
)
return

yield delta.content or ""
except KeyboardInterrupt:
response.close()

def handle(
self,
prompt: str,
model: str,
temperature: float,
top_p: float,
caching: bool,
functions: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> str:
disable_stream = cfg.get("DISABLE_STREAMING") == "true"
messages = self.make_messages(prompt.strip())
generator = self.get_completion(
model=model,
temperature=temperature,
top_p=top_p,
messages=messages,
functions=functions,
caching=caching,
**kwargs,
)
return self.printer(generator, not disable_stream)
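For reference, a minimal sketch (not part of the diff) of the model auto-selection and provider check implemented above, assuming LLM_API_PROVIDER is 'gemini' and the provider defaults added in config.py:

# Assumed provider and the provider-specific defaults from config.py.
provider = "gemini"
defaults = {"gemini": "gemini/gemini-2.0-flash", "openai": "gpt-4o"}

model = "auto_select_default_model"  # placeholder value of DEFAULT_MODEL
if model == "auto_select_default_model":
    model = defaults[provider]  # -> "gemini/gemini-2.0-flash"

# Same rule as validate_model_with_provider(): the model name must match the provider.
assert ("gemini" if provider == "gemini" else "gpt") in model.lower()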
2 changes: 1 addition & 1 deletion sgpt/integration.py
@@ -22,6 +22,6 @@
fi
}
zle -N _sgpt_zsh
bindkey ^l _sgpt_zsh
bindkey ^k _sgpt_zsh
# Shell-GPT integration ZSH v0.2
"""