Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions .github/workflows/integration-vector-io-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ jobs:

- name: Build Llama Stack
run: |
uv run --no-sync llama stack build --template ci-tests --image-type venv
uv run --no-sync llama stack build --distro starter --image-type venv --single-provider "vector_io=${{ matrix.vector-io-provider }}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious, should we stick with ci-tests distro similar to other integration tests? Or is there a benefit to starter here


- name: Check Storage and Memory Available Before Tests
if: ${{ always() }}
Expand All @@ -154,24 +154,23 @@ jobs:

- name: Run Vector IO Integration Tests
env:
ENABLE_CHROMADB: ${{ matrix.vector-io-provider == 'remote::chromadb' && 'true' || '' }}
# Set environment variables based on provider
MILVUS_URL: ${{ matrix.vector-io-provider == 'inline::milvus' && 'dummy' || '' }}
CHROMADB_URL: ${{ matrix.vector-io-provider == 'remote::chromadb' && 'http://localhost:8000' || '' }}
ENABLE_PGVECTOR: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'true' || '' }}
PGVECTOR_HOST: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'localhost' || '' }}
PGVECTOR_PORT: ${{ matrix.vector-io-provider == 'remote::pgvector' && '5432' || '' }}
PGVECTOR_DB: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
PGVECTOR_USER: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
PGVECTOR_PASSWORD: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
ENABLE_QDRANT: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'true' || '' }}
QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }}
ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }}
WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }}
FAISS_URL: ${{ matrix.vector-io-provider == 'inline::faiss' && 'dummy' || '' }}
SQLITE_VEC_URL: ${{ matrix.vector-io-provider == 'inline::sqlite-vec' && 'dummy' || '' }}
run: |
echo "Testing provider: ${{ matrix.vector-io-provider }}"
echo "Environment variables set for this provider"

uv run --no-sync \
pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
tests/integration/vector_io \
--embedding-model inline::sentence-transformers/nomic-ai/nomic-embed-text-v1.5 \
--embedding-dimension 768
pytest -sv --stack-config ~/.llama/distributions/starter/starter-filtered-run.yaml \
tests/integration/vector_io

- name: Check Storage and Memory Available After Tests
if: ${{ always() }}
Expand Down
102 changes: 97 additions & 5 deletions llama_stack/cli/stack/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,84 @@
DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"


def _apply_single_provider_filter(build_config: BuildConfig, single_provider_arg: str) -> BuildConfig:
"""Filter a distribution to only include specified providers for certain APIs."""
provider_filters: dict[str, str] = {}
for api_provider in single_provider_arg.split(","):
if "=" not in api_provider:
cprint(
"Could not parse `--single-provider`. Please ensure the list is in the format api1=provider1,api2=provider2",
color="red",
file=sys.stderr,
)
sys.exit(1)
api, provider_type = api_provider.split("=")
provider_filters[api] = provider_type

# Create a copy of the build config to modify
filtered_build_config = BuildConfig(
image_type=build_config.image_type,
image_name=build_config.image_name,
external_providers_dir=build_config.external_providers_dir,
external_apis_dir=build_config.external_apis_dir,
distribution_spec=DistributionSpec(
providers={},
description=build_config.distribution_spec.description,
),
)

# Copy all providers, but filter the specified APIs
for api, providers in build_config.distribution_spec.providers.items():
if api in provider_filters:
target_provider_type = provider_filters[api]
filtered_providers = [p for p in providers if p.provider_type == target_provider_type]
if not filtered_providers:
cprint(
f"Provider {target_provider_type} not found in distribution for API {api}",
color="red",
file=sys.stderr,
)
sys.exit(1)
filtered_build_config.distribution_spec.providers[api] = filtered_providers
else:
# Keep all providers for unfiltered APIs
filtered_build_config.distribution_spec.providers[api] = providers

return filtered_build_config


def _generate_filtered_run_config(
build_config: BuildConfig,
build_dir: Path,
distro_name: str,
) -> Path:
"""
Generate a filtered run.yaml by starting with the original distribution's run.yaml
and filtering the providers according to the build_config.
"""
# Load the original distribution's run.yaml
distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"

with importlib.resources.as_file(distro_resource) as path:
with open(path) as f:
original_config = yaml.safe_load(f)

# Apply provider filtering to the loaded config
for api, providers in build_config.distribution_spec.providers.items():
if api in original_config.get("providers", {}):
# Filter this API to only include the providers from build_config
provider_types = {p.provider_type for p in providers}
filtered_providers = [p for p in original_config["providers"][api] if p["provider_type"] in provider_types]
original_config["providers"][api] = filtered_providers

# Write the filtered run config
run_config_file = build_dir / f"{distro_name}-filtered-run.yaml"
with open(run_config_file, "w") as f:
yaml.dump(original_config, f, sort_keys=False)

return run_config_file


@lru_cache
def available_distros_specs() -> dict[str, BuildConfig]:
import yaml
Expand Down Expand Up @@ -93,6 +171,11 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
build_config = available_distros[distro_name]

# Apply single-provider filtering if specified
if args.single_provider:
build_config = _apply_single_provider_filter(build_config, args.single_provider)

if args.image_type:
build_config.image_type = args.image_type
else:
Expand Down Expand Up @@ -245,6 +328,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
image_name=image_name,
config_path=args.config,
distro_name=distro_name,
is_filtered=bool(args.single_provider),
)

except (Exception, RuntimeError) as exc:
Expand Down Expand Up @@ -363,6 +447,7 @@ def _run_stack_build_command_from_build_config(
image_name: str | None = None,
distro_name: str | None = None,
config_path: str | None = None,
is_filtered: bool = False,
) -> Path | Traversable:
image_name = image_name or build_config.image_name
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
Expand Down Expand Up @@ -435,12 +520,19 @@ def _run_stack_build_command_from_build_config(
raise RuntimeError(f"Failed to build image {image_name}")

if distro_name:
# copy run.yaml from distribution to build_dir instead of generating it again
distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
run_config_file = build_dir / f"{distro_name}-run.yaml"
# If single-provider filtering was applied, generate a filtered run config
# Otherwise, copy run.yaml from distribution as before
if is_filtered:
run_config_file = _generate_filtered_run_config(build_config, build_dir, distro_name)
distro_path = run_config_file # Use the generated file as the distro_path
else:
# copy run.yaml from distribution to build_dir instead of generating it again
distro_resource = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml"
run_config_file = build_dir / f"{distro_name}-run.yaml"

with importlib.resources.as_file(distro_path) as path:
shutil.copy(path, run_config_file)
with importlib.resources.as_file(distro_resource) as path:
shutil.copy(path, run_config_file)
distro_path = run_config_file # Update distro_path to point to the copied file

cprint("Build Successful!", color="green", file=sys.stderr)
cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr)
Expand Down
7 changes: 7 additions & 0 deletions llama_stack/cli/stack/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ def _add_arguments(self):
help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
)

self.parser.add_argument(
"--single-provider",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice. I will migrate this to list-deps too

type=str,
default=None,
help="Limit a distribution to a single provider for specific APIs. Format: api1=provider1,api2=provider2. Use with --distro to filter an existing distribution.",
)

def _run_stack_build_command(self, args: argparse.Namespace) -> None:
# always keep implementation completely silo-ed away from CLI so CLI
# can be fast to load and reduces dependencies
Expand Down
2 changes: 2 additions & 0 deletions llama_stack/distributions/ci-tests/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::weaviate
- provider_type: remote::qdrant
files:
- provider_type: inline::localfs
safety:
Expand Down
15 changes: 15 additions & 0 deletions llama_stack/distributions/ci-tests/run.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,21 @@ providers:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/weaviate_registry.db
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/qdrant_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
Expand Down
2 changes: 2 additions & 0 deletions llama_stack/distributions/starter-gpu/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::weaviate
- provider_type: remote::qdrant
files:
- provider_type: inline::localfs
safety:
Expand Down
15 changes: 15 additions & 0 deletions llama_stack/distributions/starter-gpu/run.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,21 @@ providers:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/weaviate_registry.db
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/qdrant_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
Expand Down
2 changes: 2 additions & 0 deletions llama_stack/distributions/starter/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ distribution_spec:
- provider_type: inline::milvus
- provider_type: remote::chromadb
- provider_type: remote::pgvector
- provider_type: remote::weaviate
- provider_type: remote::qdrant
files:
- provider_type: inline::localfs
safety:
Expand Down
15 changes: 15 additions & 0 deletions llama_stack/distributions/starter/run.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,21 @@ providers:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
provider_type: remote::weaviate
config:
weaviate_api_key: null
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/weaviate_registry.db
- provider_id: ${env.QDRANT_URL:+qdrant}
provider_type: remote::qdrant
config:
api_key: ${env.QDRANT_API_KEY:=}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/qdrant_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
Expand Down
14 changes: 14 additions & 0 deletions llama_stack/distributions/starter/starter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig


Expand Down Expand Up @@ -113,6 +115,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
BuildProvider(provider_type="inline::milvus"),
BuildProvider(provider_type="remote::chromadb"),
BuildProvider(provider_type="remote::pgvector"),
BuildProvider(provider_type="remote::weaviate"),
BuildProvider(provider_type="remote::qdrant"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
"safety": [
Expand Down Expand Up @@ -221,6 +225,16 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
password="${env.PGVECTOR_PASSWORD:=}",
),
),
Provider(
provider_id="${env.WEAVIATE_CLUSTER_URL:+weaviate}",
provider_type="remote::weaviate",
config=WeaviateVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.QDRANT_URL:+qdrant}",
provider_type="remote::qdrant",
config=QdrantVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
],
"files": [files_provider],
},
Expand Down
4 changes: 2 additions & 2 deletions llama_stack/providers/utils/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class JobStatus(Enum):
completed = "completed"


type JobID = str
type JobType = str
JobID = str
JobType = str


class JobArtifact(BaseModel):
Expand Down
28 changes: 28 additions & 0 deletions tests/integration/suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,29 @@ class Setup(BaseModel):
"text_model": "groq/llama-3.3-70b-versatile",
},
),
"milvus": Setup(
name="milvus",
description="Milvus vector database provider for vector_io tests",
env={
"MILVUS_URL": "dummy",
},
),
"chromadb": Setup(
name="chromadb",
description="ChromaDB vector database provider for vector_io tests",
env={
"CHROMADB_URL": "http://localhost:8000",
},
),
"pgvector": Setup(
name="pgvector",
description="PGVector database provider for vector_io tests",
env={
"PGVECTOR_DB": "llama_stack_test",
"PGVECTOR_USER": "postgres",
"PGVECTOR_PASSWORD": "password",
},
),
}


Expand All @@ -179,4 +202,9 @@ class Setup(BaseModel):
roots=["tests/integration/inference/test_vision_inference.py"],
default_setup="ollama-vision",
),
"vector_io": Suite(
name="vector_io",
roots=["tests/integration/vector_io"],
default_setup="milvus",
),
}
Loading
Loading