2 changes: 2 additions & 0 deletions python/databricks-agent/.gitignore
@@ -0,0 +1,2 @@
*.db
mlruns*
Empty file.
78 changes: 78 additions & 0 deletions python/databricks-agent/agent.py
@@ -0,0 +1,78 @@
import warnings
from typing import Any, Generator

import mlflow
import openai
from databricks.sdk import WorkspaceClient
from mlflow.entities import SpanType
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
    output_to_responses_items_stream,
    to_chat_completions_input,
)

# TODO: Replace with your model serving endpoint
LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4-5"

# TODO: Update with your system prompt
SYSTEM_PROMPT = """
You are a helpful assistant that provides brief, clear responses.
"""


class SimpleChatAgent(ResponsesAgent):
"""
Simple chat agent that calls an LLM using the Databricks OpenAI client API.

You can replace this with your own agent.
The decorators @mlflow.trace tell MLflow Tracing to track calls to the agent.
"""

    def __init__(self):
        self.workspace_client = WorkspaceClient()
        self.client = self.workspace_client.serving_endpoints.get_open_ai_client()
        self.llm_endpoint = LLM_ENDPOINT_NAME
        self.SYSTEM_PROMPT = SYSTEM_PROMPT

    @mlflow.trace(span_type=SpanType.LLM)
    def call_llm(self, messages: list[dict[str, Any]]) -> Generator[dict[str, Any], None, None]:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="PydanticSerializationUnexpectedValue")
            for chunk in self.client.chat.completions.create(
                model=self.llm_endpoint,
                messages=to_chat_completions_input(messages),
                stream=True,
            ):
                yield chunk.to_dict()

    # With autologging, you do not need @mlflow.trace here, but you can add it to override the span type.
    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

    # With autologging, you do not need @mlflow.trace here, but you can add it to override the span type.
    def predict_stream(
        self, request: ResponsesAgentRequest
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        messages = [{"role": "system", "content": self.SYSTEM_PROMPT}] + [
            i.model_dump() for i in request.input
        ]
        yield from output_to_responses_items_stream(chunks=self.call_llm(messages))


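# Autologging traces every OpenAI-client call made by the agent; set_model registers
# the agent instance that MLflow should serve when this file is logged as code.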
mlflow.openai.autolog()
AGENT = SimpleChatAgent()
mlflow.models.set_model(AGENT)

if __name__ == "__main__":
    for event in AGENT.predict_stream(
        ResponsesAgentRequest(input=[{"role": "user", "content": "What is 5+5?"}])
    ):
        print(event.model_dump(exclude_none=True))
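    # A minimal non-streaming check (sketch): predict() aggregates the
    # "response.output_item.done" events that predict_stream() yields above.
    response = AGENT.predict(
        ResponsesAgentRequest(input=[{"role": "user", "content": "What is 5+5?"}])
    )
    print(response.model_dump(exclude_none=True))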
33 changes: 33 additions & 0 deletions python/databricks-agent/deploy_agent.py
@@ -0,0 +1,33 @@
import mlflow
from databricks import agents
from mlflow.models.resources import DatabricksServingEndpoint

from agent import LLM_ENDPOINT_NAME

# TODO: Replace with your Unity Catalog model name (catalog.schema.model_name)
UC_MODEL_NAME = "workspace.default.databricks_agent"

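# Declaring the LLM serving endpoint as a resource lets Databricks configure
# authentication to it automatically when the agent is deployed.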
resources = [
    DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME),
]

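# Register the model in Unity Catalog rather than the workspace model registry.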
mlflow.set_registry_uri("databricks-uc")

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="dbos_demo_agent",
        python_model="agent.py",
        resources=resources,
        registered_model_name=UC_MODEL_NAME,
    )

print(f"Model URI: {logged_agent_info.model_uri}")
print(f"Model version: {logged_agent_info.registered_model_version}")

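# agents.deploy provisions a Model Serving endpoint for the registered model version;
# scale_to_zero_enabled lets the endpoint scale down when it receives no traffic.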
deployment = agents.deploy(
    model_name=UC_MODEL_NAME,
    model_version=logged_agent_info.registered_model_version,
    scale_to_zero_enabled=True,
)

print(f"Deployment endpoint: {deployment.query_endpoint}")
11 changes: 11 additions & 0 deletions python/databricks-agent/pyproject.toml
@@ -0,0 +1,11 @@
[project]
name = "databricks-agent"
version = "0.1.0"
description = "Simple chat agent using Databricks + MLflow ResponsesAgent"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
    "databricks-agents>=1.9.3",
    "databricks-openai>=0.10.0",
    "mlflow-skinny[databricks]>=3.9.0",
]
19 changes: 19 additions & 0 deletions python/databricks-agent/query_agent.py
@@ -0,0 +1,19 @@
import sys

from databricks.sdk import WorkspaceClient

# TODO: Replace with your deployed agent endpoint name
DEFAULT_ENDPOINT = "default-schema-databricks_agent"

endpoint = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_ENDPOINT

client = WorkspaceClient().serving_endpoints.get_open_ai_client()

response = client.responses.create(
    model=endpoint,
    input=[{"role": "user", "content": "What is 5+5?"}],
    stream=True,
)

for event in response:
    print(event)
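
# A non-streaming variant (sketch): the same call without stream=True should return a
# single response object; output_text is the OpenAI SDK's convenience accessor for the text.
response = client.responses.create(
    model=endpoint,
    input=[{"role": "user", "content": "What is 5+5?"}],
)
print(response.output_text)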