Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ MCP server to interact with Infrahub

## Requirements

- Python 3.8+
- Python 3.13+
- fastmcp
- infrahub_sdk

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"fastmcp>=2.10.5",
"infrahub-sdk[all]>=1.13.5",
"infrahub-sdk>=1.13.5",
]

[dependency-groups]
Expand Down
59 changes: 52 additions & 7 deletions src/infrahub_mcp_server/branch.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,63 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Annotated

from fastmcp import Context, FastMCP
from infrahub_sdk.branch import BranchData
from infrahub_sdk.exceptions import GraphQLError
from mcp.types import ToolAnnotations
from pydantic import Field

from infrahub_mcp_server.utils import MCPResponse, MCPToolStatus, _log_and_return_error

if TYPE_CHECKING:
from infrahub_sdk import InfrahubClient

mcp: FastMCP = FastMCP(name="Infrahub Branch")
mcp: FastMCP = FastMCP(name="Infrahub Branches")


@mcp.tool(
tags=["branches", "create"],
annotations=ToolAnnotations(readOnlyHint=False, idempotentHint=True, destructiveHint=False),
)
async def branch_create(
ctx: Context,
name: Annotated[str, Field(description="Name of the branch to create.")],
sync_with_git: Annotated[bool, Field(default=False, description="Whether to sync the branch with git.")],
) -> MCPResponse[dict[str, str]]:
"""Create a new branch in infrahub.

Parameters:
name: Name of the branch to create.
sync_with_git: Whether to sync the branch with git. Defaults to False.

@mcp.tool
async def branch_create(ctx: Context, name: str, sync_with_git: bool = False) -> dict:
"""Create a new branch in infrahub."""
Returns:
Dictionary with success status and branch details.
"""

client: InfrahubClient = ctx.request_context.lifespan_context.client
branch = await client.branch.create(branch_name=name, sync_with_git=sync_with_git, background_execution=False)
ctx.info(f"Creating branch {name} in Infrahub...")

try:
branch = await client.branch.create(branch_name=name, sync_with_git=sync_with_git, background_execution=False)

except GraphQLError as exc:
return _log_and_return_error(ctx=ctx, error=exc, remediation="Check the branch name or your permissions.")

return MCPResponse(
status=MCPToolStatus.SUCCESS,
data={
"name": branch.name,
"id": branch.id,
},
)


@mcp.tool(tags=["branches", "retrieve"], annotations=ToolAnnotations(readOnlyHint=True))
async def get_branches(ctx: Context) -> MCPResponse[dict[str, BranchData]]:
"""Retrieve all branches from infrahub."""

client: InfrahubClient = ctx.request_context.lifespan_context.client
ctx.info("Fetching all branches from Infrahub...")

branches: dict[str, BranchData] = await client.branch.all()

return {"name": branch.name, "id": branch.id}
return MCPResponse(status=MCPToolStatus.SUCCESS, data=branches)
8 changes: 8 additions & 0 deletions src/infrahub_mcp_server/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
NAMESPACES_INTERNAL = ["Internal", "Profile", "Template"]

schema_attribute_type_mapping = {
"Text": "String",
"Number": "Integer",
"Boolean": "Boolean",
"DateTime": "DateTime",
"Enum": "String",
}
37 changes: 29 additions & 8 deletions src/infrahub_mcp_server/gql.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,44 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Annotated, Any

from fastmcp import Context, FastMCP
from mcp.types import ToolAnnotations
from pydantic import Field

from infrahub_mcp_server.utils import MCPResponse, MCPToolStatus

if TYPE_CHECKING:
from infrahub_sdk import InfrahubClient

mcp: FastMCP = FastMCP(name="Infrahub GraphQL")


@mcp.tool
async def get_graphql_schema(ctx: Context) -> str:
"""Retrieve the GraphQL schema from Infrahub"""
@mcp.tool(tags=["schemas", "retrieve"], annotations=ToolAnnotations(readOnlyHint=True))
async def get_graphql_schema(ctx: Context) -> MCPResponse[str]:
"""Retrieve the GraphQL schema from Infrahub

Parameters:
None

Returns:
MCPResponse with the GraphQL schema as a string.
"""
client: InfrahubClient = ctx.request_context.lifespan_context.client
resp = await client._get(url=f"{client.address}/schema.graphql") # noqa: SLF001
return resp.text
return MCPResponse(status=MCPToolStatus.SUCCESS, data=resp.text)


@mcp.tool(tags=["schemas", "retrieve"], annotations=ToolAnnotations(readOnlyHint=False))
async def query_graphql(
ctx: Context, query: Annotated[str, Field(description="GraphQL query to execute.")]
) -> MCPResponse[dict[str, Any]]:
"""Execute a GraphQL query against Infrahub.

Parameters:
query: GraphQL query to execute.

Returns:
MCPResponse with the result of the query.

@mcp.tool
async def query_graphql(ctx: Context, query: str) -> dict:
"""Execute a GraphQL query against Infrahub."""
"""
client: InfrahubClient = ctx.request_context.lifespan_context.client
return await client.execute_graphql(query=query)
230 changes: 230 additions & 0 deletions src/infrahub_mcp_server/nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
from typing import TYPE_CHECKING, Annotated, Any

from fastmcp import Context, FastMCP
from infrahub_sdk.exceptions import GraphQLError, SchemaNotFoundError
from infrahub_sdk.types import Order
from mcp.types import ToolAnnotations
from pydantic import Field

from infrahub_mcp_server.constants import schema_attribute_type_mapping
from infrahub_mcp_server.utils import MCPResponse, MCPToolStatus, _log_and_return_error, convert_node_to_dict

if TYPE_CHECKING:
from infrahub_sdk.client import InfrahubClient

mcp: FastMCP = FastMCP(name="Infrahub Nodes")


@mcp.tool(tags=["nodes", "retrieve"], annotations=ToolAnnotations(readOnlyHint=True))
async def get_nodes(
ctx: Context,
kind: Annotated[str, Field(description="Kind of the objects to retrieve.")],
branch: Annotated[
str | None,
Field(default=None, description="Branch to retrieve the objects from. Defaults to None (uses default branch)."),
],
filters: Annotated[dict[str, Any] | None, Field(default=None, description="Dictionary of filters to apply.")],
partial_match: Annotated[bool, Field(default=False, description="Whether to use partial matching for filters.")],
) -> MCPResponse[list[str]]:
"""Get all objects of a specific kind from Infrahub.

To retrieve the list of available kinds, use the `get_schema_mapping` tool.
To retrieve the list of available filters for a specific kind, use the `get_node_filters` tool.

Parameters:
kind: Kind of the objects to retrieve.
branch: Branch to retrieve the objects from. Defaults to None (uses default branch).
filters: Dictionary of filters to apply.
partial_match: Whether to use partial matching for filters.

Returns:
MCPResponse with success status and objects.

"""
client: InfrahubClient = ctx.request_context.lifespan_context.client
ctx.info(f"Fetching nodes of kind: {kind} with filters: {filters} from Infrahub...")

# Verify if the kind exists in the schema and guide Tool if not
try:
schema = await client.schema.get(kind=kind, branch=branch)
except SchemaNotFoundError:
error_msg = f"Schema not found for kind: {kind}."
remediation_msg = "Use the `get_schema_mapping` tool to list available kinds."
return _log_and_return_error(ctx=ctx, error=error_msg, remediation=remediation_msg)

# TODO: Verify if the filters are valid for the kind and guide Tool if not

try:
if filters:
nodes = await client.filters(
kind=schema.kind,
branch=branch,
partial_match=partial_match,
parallel=True,
order=Order(disable=True),
populate_store=True,
prefetch_relationships=True,
**filters,
)
else:
nodes = await client.all(
kind=schema.kind,
branch=branch,
parallel=True,
order=Order(disable=True),
populate_store=True,
prefetch_relationships=True,
)
except GraphQLError as exc:
return _log_and_return_error(ctx=ctx, error=exc, remediation="Check the provided filters or the kind name.")

# Format the response with serializable data
# serialized_nodes = []
# for node in nodes:
# node_data = await convert_node_to_dict(obj=node, branch=branch)
# serialized_nodes.append(node_data)
serialized_nodes = [obj.display_label for obj in nodes]

# Return the serialized response
ctx.debug(f"Retrieved {len(serialized_nodes)} nodes of kind {kind}")

return MCPResponse(
status=MCPToolStatus.SUCCESS,
data=serialized_nodes,
)


@mcp.tool(tags=["nodes", "filters", "retrieve"], annotations=ToolAnnotations(readOnlyHint=True))
async def get_node_filters(
ctx: Context,
kind: Annotated[str, Field(description="Kind of the objects to retrieve.")],
branch: Annotated[
str | None,
Field(default=None, description="Branch to retrieve the objects from. Defaults to None (uses default branch)."),
],
) -> MCPResponse[dict[str, str]]:
"""Retrieve all the available filters for a specific schema node kind.

There's multiple types of filters
attribute filters are in the form attribute__value

relationship filters are in the form relationship__attribute__value
you can find more information on the peer node of the relationship using the `get_schema` tool

Filters that start with parent refer to a related generic schema node.
You can find the type of that related node by inspected the output of the `get_schema` tool.

Parameters:
kind: Kind of the objects to retrieve.
branch: Branch to retrieve the objects from. Defaults to None (uses default branch).

Returns:
MCPResponse with success status and filters.
"""
client: InfrahubClient = ctx.request_context.lifespan_context.client
ctx.info(f"Fetching available filters for kind: {kind} from Infrahub...")

# Verify if the kind exists in the schema and guide Tool if not
try:
schema = await client.schema.get(kind=kind, branch=branch)
except SchemaNotFoundError:
error_msg = f"Schema not found for kind: {kind}."
remediation_msg = "Use the `get_schema_mapping` tool to list available kinds."
return _log_and_return_error(ctx=ctx, error=error_msg, remediation=remediation_msg)

filters = {
f"{attribute.name}__value": schema_attribute_type_mapping.get(attribute.kind, "String")
for attribute in schema.attributes
}

for relationship in schema.relationships:
relationship_schema = await client.schema.get(kind=relationship.peer)
relationship_filters = {
f"{relationship.name}__{attribute.name}__value": schema_attribute_type_mapping.get(attribute.kind, "String")
for attribute in relationship_schema.attributes
}
filters.update(relationship_filters)

return MCPResponse(
status=MCPToolStatus.SUCCESS,
data=filters,
)


@mcp.tool(tags=["nodes", "retrieve"], annotations=ToolAnnotations(readOnlyHint=True))
async def get_related_nodes(
ctx: Context,
kind: Annotated[str, Field(description="Kind of the objects to retrieve.")],
relation: Annotated[str, Field(description="Name of the relation to fetch.")],
filters: Annotated[dict[str, Any] | None, Field(default=None, description="Dictionary of filters to apply.")],
branch: Annotated[
str | None,
Field(default=None, description="Branch to retrieve the objects from. Defaults to None (uses default branch)."),
],
) -> MCPResponse[list[dict[str, Any]]]:
"""Retrieve related nodes by relation name and a kind.

Args:
kind: Kind of the node to fetch.
filters: Filters to apply on the node to fetch.
relation: Name of the relation to fetch.
branch: Branch to fetch the node from. Defaults to None (uses default branch).

Returns:
MCPResponse with success status and objects.

"""
client: InfrahubClient = ctx.request_context.lifespan_context.client
filters = filters or {}
if branch:
ctx.info(f"Fetching nodes related to {kind} with filters {filters} in branch {branch} from Infrahub...")
else:
ctx.info(f"Fetching nodes related to {kind} with filters {filters} from Infrahub...")

try:
node_id = node_hfid = None
if filters.get("ids"):
node_id = filters["ids"][0]
elif filters.get("hfid"):
node_hfid = filters["hfid"]
if node_id:
node = await client.get(
kind=kind,
id=node_id,
branch=branch,
include=[relation],
prefetch_relationships=True,
populate_store=True,
)
elif node_hfid:
node = await client.get(
kind=kind,
hfid=node_hfid,
branch=branch,
include=[relation],
prefetch_relationships=True,
populate_store=True,
)
except Exception as exc: # noqa: BLE001
return _log_and_return_error(exc)

rel = getattr(node, relation, None)
if not rel:
_log_and_return_error(
ctx=ctx,
error=f"Relation '{relation}' not found in kind '{kind}'.",
remediation="Check the schema for the kind to confirm if the relation exists.",
)
peers = [
await convert_node_to_dict(
branch=branch,
obj=peer.peer,
include_id=True,
)
for peer in rel.peers
]

return MCPResponse(
status=MCPToolStatus.SUCCESS,
data=peers,
)
Loading