Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions docs/features/prompt-tools.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# LLM Prompt Tools

InvokeAI includes two built-in tools that use local language models to help you write better prompts. Both tools appear as small buttons in the top-right corner of the positive prompt area and are only visible when you have a compatible model installed.

## Expand Prompt

Takes your short prompt and expands it into a detailed, vivid description suitable for image generation.

**How to use:**

1. Type a brief prompt (e.g. "a cat in a garden")
2. Click the sparkle button in the prompt area
3. Select a Text LLM model from the dropdown
4. Click **Expand**
5. Your prompt is replaced with the expanded version

**Compatible models:** Any HuggingFace model with a `ForCausalLM` architecture. Recommended options:

| Model | Size | HuggingFace ID |
|-------|------|----------------|
| Qwen2.5 1.5B Instruct | ~3 GB | `Qwen/Qwen2.5-1.5B-Instruct` |
| Phi-3 Mini Instruct | ~7.5 GB | `microsoft/Phi-3-mini-4k-instruct` |
| TinyLlama Chat | ~2 GB | `TinyLlama/TinyLlama-1.1B-Chat-v1.0` |

Install by pasting the HuggingFace ID into the Model Manager. The model is automatically detected as a **Text LLM** type.

## Image to Prompt

Upload an image and generate a descriptive prompt from it using a vision-language model.

**How to use:**

1. Click the image button in the prompt area
2. Select a LLaVA OneVision model from the dropdown
3. Click **Upload Image** and select an image
4. Click **Generate Prompt**
5. The generated description is set as your prompt

**Compatible models:** LLaVA OneVision models (already supported by InvokeAI).

## Undo

Both tools overwrite your current prompt. You can undo this change:

- Press **Ctrl+Z** (or **Cmd+Z** on macOS) in the prompt textarea within 30 seconds
- The undo state is cleared when you start typing manually

## Workflow Node

A **Text LLM** node is also available in the workflow editor for use in automated pipelines. It accepts a prompt string and model selection as inputs and outputs the expanded text as a string.
173 changes: 171 additions & 2 deletions invokeai/app/api/routers/utilities.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
import asyncio
import logging
from pathlib import Path
from typing import Optional, Union

import torch
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
from fastapi import Body
from fastapi import Body, HTTPException
from fastapi.routing import APIRouter
from pydantic import BaseModel
from pydantic import BaseModel, Field
from pyparsing import ParseException
from transformers import AutoProcessor, AutoTokenizer, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor

from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.backend.llava_onevision_pipeline import LlavaOnevisionPipeline
from invokeai.backend.model_manager.taxonomy import ModelType
from invokeai.backend.text_llm_pipeline import DEFAULT_SYSTEM_PROMPT, TextLLMPipeline
from invokeai.backend.util.devices import TorchDevice

logger = logging.getLogger(__name__)

utilities_router = APIRouter(prefix="/v1/utilities", tags=["utilities"])

Expand Down Expand Up @@ -42,3 +56,158 @@ async def parse_dynamicprompts(
prompts = [prompt]
error = str(e)
return DynamicPromptsResponse(prompts=prompts if prompts else [""], error=error)


# --- Expand Prompt ---


class ExpandPromptRequest(BaseModel):
    """Request body for the expand-prompt endpoint."""

    prompt: str = Field(description="The brief prompt to expand.")
    model_key: str = Field(description="Key of an installed TextLLM model to run the expansion with.")
    max_tokens: int = Field(default=300, ge=1, le=2048, description="Maximum number of new tokens to generate.")
    system_prompt: str | None = Field(
        default=None,
        description="Optional system prompt; the backend's DEFAULT_SYSTEM_PROMPT is used when omitted.",
    )


class ExpandPromptResponse(BaseModel):
    """Response body for the expand-prompt endpoint."""

    # The LLM-generated expansion of the input prompt.
    expanded_prompt: str
    # NOTE(review): never populated by the endpoint — failures are reported via
    # HTTPException instead. Kept in the schema; confirm clients before removing.
    error: str | None = None


def _resolve_model_path(model_config_path: str) -> Path:
"""Resolve a model config path to an absolute path."""
model_path = Path(model_config_path)
if model_path.is_absolute():
return model_path.resolve()
base_models_path = ApiDependencies.invoker.services.configuration.models_path
return (base_models_path / model_path).resolve()


def _run_expand_prompt(prompt: str, model_key: str, max_tokens: int, system_prompt: str | None) -> str:
    """Run text LLM inference synchronously (called from thread).

    Args:
        prompt: The user's brief prompt to expand.
        model_key: Key of an installed model; must be of type ModelType.TextLLM.
        max_tokens: Maximum number of new tokens the model may generate.
        system_prompt: Optional system prompt; DEFAULT_SYSTEM_PROMPT is used when None.

    Returns:
        The text generated by the model.

    Raises:
        ValueError: If the model exists but is not a TextLLM model (mapped to
            422 by the endpoint).
    """
    model_manager = ApiDependencies.invoker.services.model_manager
    # Store lookup; an unknown key raises UnknownModelException, which the
    # endpoint maps to a 404 response.
    model_config = model_manager.store.get_model(model_key)

    # Guard against running a non-text model (e.g. a vision model) through the
    # text pipeline.
    if model_config.type != ModelType.TextLLM:
        raise ValueError(f"Model '{model_key}' is not a TextLLM model (got {model_config.type})")

    loaded_model = model_manager.load.load_model(model_config)

    # model_on_device() keeps the model resident on its execution device for the
    # duration of inference. NOTE(review): the first tuple element is unused
    # here — presumably a lock/cache handle; confirm against the model manager.
    with loaded_model.model_on_device() as (_, model):
        # The tokenizer is loaded from the model's own directory;
        # local_files_only prevents an unexpected network fetch at request time.
        model_abs_path = _resolve_model_path(model_config.path)
        tokenizer = AutoTokenizer.from_pretrained(model_abs_path, local_files_only=True)

        pipeline = TextLLMPipeline(model, tokenizer)
        # Read the device from the model's own parameters so inference runs
        # wherever the model manager actually placed it.
        model_device = next(model.parameters()).device
        output = pipeline.run(
            prompt=prompt,
            system_prompt=system_prompt or DEFAULT_SYSTEM_PROMPT,
            max_new_tokens=max_tokens,
            device=model_device,
            dtype=TorchDevice.choose_torch_dtype(),
        )

    return output


@utilities_router.post(
    "/expand-prompt",
    operation_id="expand_prompt",
    responses={
        200: {"model": ExpandPromptResponse},
    },
)
async def expand_prompt(body: ExpandPromptRequest) -> ExpandPromptResponse:
    """Expand a brief prompt into a detailed image generation prompt using a text LLM.

    Inference runs in a worker thread via asyncio.to_thread so the event loop
    is not blocked.

    Raises:
        HTTPException: 404 if the model key is unknown, 422 if the model is not
            a TextLLM model, 500 for any other failure.
    """

    def _call() -> str:
        # PyTorch grad mode is thread-local, so no_grad() must be entered in the
        # worker thread that actually runs inference. Wrapping the
        # asyncio.to_thread() call in the event-loop thread would have no
        # effect on the thread doing the work.
        with torch.no_grad():
            return _run_expand_prompt(body.prompt, body.model_key, body.max_tokens, body.system_prompt)

    try:
        expanded = await asyncio.to_thread(_call)
        return ExpandPromptResponse(expanded_prompt=expanded)
    except UnknownModelException as e:
        raise HTTPException(status_code=404, detail=f"Model '{body.model_key}' not found") from e
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        # logger.exception records the traceback, unlike logger.error.
        logger.exception("Error expanding prompt: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


# --- Image to Prompt ---


class ImageToPromptRequest(BaseModel):
    """Request body for the image-to-prompt endpoint."""

    image_name: str = Field(description="Name of an image in InvokeAI's image store.")
    model_key: str = Field(description="Key of an installed LLaVA OneVision model to use.")
    instruction: str = Field(
        default="Describe this image in detail for use as an AI image generation prompt.",
        description="Instruction given to the vision-language model.",
    )


class ImageToPromptResponse(BaseModel):
    """Response body for the image-to-prompt endpoint."""

    # The description generated from the uploaded image.
    prompt: str
    # NOTE(review): never populated by the endpoint — failures are reported via
    # HTTPException instead. Kept in the schema; confirm clients before removing.
    error: str | None = None


def _run_image_to_prompt(image_name: str, model_key: str, instruction: str) -> str:
    """Run LLaVA OneVision inference synchronously (called from thread).

    Args:
        image_name: Name of an image in InvokeAI's image store.
        model_key: Key of an installed model; must be of type ModelType.LlavaOnevision.
        instruction: Instruction text given to the vision-language model.

    Returns:
        The description generated by the model.

    Raises:
        ValueError: If the model is not a LLaVA OneVision model (mapped to 422).
        TypeError: If the loaded model/processor are not the expected classes
            (mapped to 422).
    """
    model_manager = ApiDependencies.invoker.services.model_manager
    # Store lookup; an unknown key raises UnknownModelException, which the
    # endpoint maps to a 404 response.
    model_config = model_manager.store.get_model(model_key)

    # Guard against running a non-vision model through the vision pipeline.
    if model_config.type != ModelType.LlavaOnevision:
        raise ValueError(f"Model '{model_key}' is not a LLaVA OneVision model (got {model_config.type})")

    loaded_model = model_manager.load.load_model(model_config)

    # Load the image from InvokeAI's image store
    # Convert to RGB so RGBA/palette/grayscale inputs are all accepted.
    image = ApiDependencies.invoker.services.images.get_pil_image(image_name)
    image = image.convert("RGB")

    with loaded_model.model_on_device() as (_, model):
        # Defensive class checks: the config's type says LlavaOnevision, but the
        # loader's return type is not statically guaranteed.
        if not isinstance(model, LlavaOnevisionForConditionalGeneration):
            raise TypeError(f"Expected LlavaOnevisionForConditionalGeneration, got {type(model).__name__}")

        # Processor is loaded from the model's own directory; local_files_only
        # prevents an unexpected network fetch at request time.
        model_abs_path = _resolve_model_path(model_config.path)
        processor = AutoProcessor.from_pretrained(model_abs_path, local_files_only=True)
        if not isinstance(processor, LlavaOnevisionProcessor):
            raise TypeError(f"Expected LlavaOnevisionProcessor, got {type(processor).__name__}")

        pipeline = LlavaOnevisionPipeline(model, processor)
        # Run on whatever device the model manager placed the model on.
        model_device = next(model.parameters()).device
        output = pipeline.run(
            prompt=instruction,
            images=[image],
            device=model_device,
            dtype=TorchDevice.choose_torch_dtype(),
        )

    return output


@utilities_router.post(
    "/image-to-prompt",
    operation_id="image_to_prompt",
    responses={
        200: {"model": ImageToPromptResponse},
    },
)
async def image_to_prompt(body: ImageToPromptRequest) -> ImageToPromptResponse:
    """Generate a descriptive prompt from an image using a vision-language model.

    Inference runs in a worker thread via asyncio.to_thread so the event loop
    is not blocked.

    Raises:
        HTTPException: 404 if the model key is unknown, 422 if the model is not
            a LLaVA OneVision model, 500 for any other failure.
    """

    def _call() -> str:
        # PyTorch grad mode is thread-local, so no_grad() must be entered in the
        # worker thread that actually runs inference. Wrapping the
        # asyncio.to_thread() call in the event-loop thread would have no
        # effect on the thread doing the work.
        with torch.no_grad():
            return _run_image_to_prompt(body.image_name, body.model_key, body.instruction)

    try:
        prompt = await asyncio.to_thread(_call)
        return ImageToPromptResponse(prompt=prompt)
    except UnknownModelException as e:
        raise HTTPException(status_code=404, detail=f"Model '{body.model_key}' not found") from e
    except (ValueError, TypeError) as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        # logger.exception records the traceback, unlike logger.error.
        logger.exception("Error generating prompt from image: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
1 change: 1 addition & 0 deletions invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ class FieldDescriptions:
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
flux_redux_conditioning = "FLUX Redux conditioning tensor"
vllm_model = "The VLLM model to use"
text_llm_model = "The text language model to use for text generation"
flux_fill_conditioning = "FLUX Fill conditioning tensor"
flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"

Expand Down
65 changes: 65 additions & 0 deletions invokeai/app/invocations/text_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import torch
from transformers import AutoTokenizer

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, InputField, UIComponent
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import ModelType
from invokeai.backend.text_llm_pipeline import DEFAULT_SYSTEM_PROMPT, TextLLMPipeline
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "text_llm",
    title="Text LLM",
    tags=["llm", "text", "prompt"],
    category="llm",
    version="1.0.0",
    classification=Classification.Beta,
)
class TextLLMInvocation(BaseInvocation):
    """Run a text language model to generate or expand text (e.g. for prompt expansion)."""

    # The user-supplied text the model responds to.
    prompt: str = InputField(
        default="",
        description="Input text prompt.",
        ui_component=UIComponent.Textarea,
    )
    # Steers the model's behavior; defaults to the backend's prompt-expansion
    # system prompt.
    system_prompt: str = InputField(
        default=DEFAULT_SYSTEM_PROMPT,
        description="System prompt that guides the model's behavior.",
        ui_component=UIComponent.Textarea,
    )
    # Model picker restricted to installed TextLLM models.
    text_llm_model: ModelIdentifierField = InputField(
        title="Text LLM Model",
        description=FieldDescriptions.text_llm_model,
        ui_model_type=ModelType.TextLLM,
    )
    # Generation length cap; bounds mirror the /expand-prompt API endpoint.
    max_tokens: int = InputField(
        default=300,
        ge=1,
        le=2048,
        description="Maximum number of tokens to generate.",
    )

    # The decorator is effective here because invoke() runs synchronously in the
    # calling thread (PyTorch grad mode is thread-local).
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> StringOutput:
        """Load the selected TextLLM model, run generation, and return the text."""
        model_config = context.models.get_config(self.text_llm_model)

        # model_on_device() keeps the model resident on its execution device for
        # the duration of generation; the first tuple element is unused here.
        with context.models.load(self.text_llm_model).model_on_device() as (_, model):
            # Tokenizer comes from the model's own directory; local_files_only
            # prevents an unexpected network fetch at invocation time.
            model_abs_path = context.models.get_absolute_path(model_config)
            tokenizer = AutoTokenizer.from_pretrained(model_abs_path, local_files_only=True)

            pipeline = TextLLMPipeline(model, tokenizer)
            # Run on whatever device the model manager placed the model on.
            model_device = next(model.parameters()).device
            output = pipeline.run(
                prompt=self.prompt,
                system_prompt=self.system_prompt,
                max_new_tokens=self.max_tokens,
                device=model_device,
                dtype=TorchDevice.choose_torch_dtype(),
            )

        return StringOutput(value=output)
2 changes: 2 additions & 0 deletions invokeai/backend/model_manager/configs/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
T2IAdapter_Diffusers_SDXL_Config,
)
from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config
from invokeai.backend.model_manager.configs.text_llm import TextLLM_Diffusers_Config
from invokeai.backend.model_manager.configs.textual_inversion import (
TI_File_SD1_Config,
TI_File_SD2_Config,
Expand Down Expand Up @@ -256,6 +257,7 @@
Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()],
Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()],
Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()],
Annotated[TextLLM_Diffusers_Config, TextLLM_Diffusers_Config.get_tag()],
# Unknown model (fallback)
Annotated[Unknown_Config, Unknown_Config.get_tag()],
],
Expand Down
44 changes: 44 additions & 0 deletions invokeai/backend/model_manager/configs/text_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import (
Literal,
Self,
)

from pydantic import Field
from typing_extensions import Any

from invokeai.backend.model_manager.configs.base import Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
common_config_paths,
get_class_name_from_config_dict_or_raise,
raise_for_override_fields,
raise_if_not_dir,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelType,
)


class TextLLM_Diffusers_Config(Diffusers_Config_Base, Config_Base):
    """Model config for text-only causal language models (e.g. Llama, Phi, Qwen, Mistral)."""

    type: Literal[ModelType.TextLLM] = Field(default=ModelType.TextLLM)
    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
    cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config from an on-disk model, raising NotAMatchError if it does not match."""
        # Diffusers-format models are directories; reject anything else up front.
        raise_if_not_dir(mod)

        # Validate caller-supplied overrides before inspecting the model files.
        raise_for_override_fields(cls, override_fields)

        # Identify the model by its declared architecture. All supported text
        # LLMs (Llama, Phi, Phi-3, Qwen2, Mistral, Gemma, GPT-NeoX, ...) share
        # the "*ForCausalLM" class-name suffix in their HF config.
        architecture = get_class_name_from_config_dict_or_raise(common_config_paths(mod.path))
        if architecture.endswith("ForCausalLM"):
            return cls(**override_fields)
        raise NotAMatchError(f"model architecture '{architecture}' is not a causal language model")
Loading
Loading