Skip to content

Postgres fixes: Concurrency issue, connection leak, and fixes to guard CRUD operations #83

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,31 @@ FROM public.ecr.aws/docker/library/python:3.12-slim

# Accept a build arg for the Guardrails token
# We'll add this to the config using the configure command below
# ARG GUARDRAILS_TOKEN
ARG GUARDRAILS_TOKEN

# Create app directory
WORKDIR /app

# Enable venv
ENV PATH="/opt/venv/bin:$PATH"

# Set the directory for nltk data
ENV NLTK_DATA=/opt/nltk_data

# Set env vars for server
ENV GR_CONFIG_FILE_PATH="sample-config.py"
ENV GR_ENV_FILE=".env"
ENV PORT=8000

# print the version just to verify
RUN python3 --version
# start the virtual environment
RUN python3 -m venv /opt/venv

# Enable venv
ENV PATH="/opt/venv/bin:$PATH"

# Install some utilities; you may not need all of these
RUN apt-get update
RUN apt-get install -y git
# Install some utilities
RUN apt-get update && \
apt-get install -y git pkg-config curl gcc g++ && \
rm -rf /var/lib/apt/lists/*

# Copy the requirements file
COPY requirements*.txt .
Expand All @@ -26,26 +35,27 @@ COPY requirements*.txt .
# If you use Poetry this step might be different
RUN pip install -r requirements-lock.txt

# Set the directory for nltk data
ENV NLTK_DATA=/opt/nltk_data

# Download punkt data
RUN python -m nltk.downloader -d /opt/nltk_data punkt

# Run the Guardrails configure command to create a .guardrailsrc file
# RUN guardrails configure --enable-metrics --enable-remote-inferencing --token $GUARDRAILS_TOKEN
RUN guardrails configure --enable-metrics --enable-remote-inferencing --token $GUARDRAILS_TOKEN

# Install any validators from the hub you want
RUN guardrails hub install hub://guardrails/valid_length
RUN guardrails hub install hub://guardrails/detect_pii --no-install-local-models && \
guardrails hub install hub://guardrails/competitor_check --no-install-local-models

# Fetch AWS RDS cert
RUN curl https://truststore.pki.rds.amazonaws.com/global/global-bundle.pem -o ./global-bundle.pem

# Copy the rest over
# We use a .dockerignore to keep unwanted files excluded
COPY . .

EXPOSE 8000
EXPOSE ${PORT}

# This is our start command; yours might be different.
# The guardrails-api is a standard FastAPI application.
# You can use whatever production server you want that supports FastAPI.
# Here we use gunicorn
CMD gunicorn --bind 0.0.0.0:8000 --timeout=90 --workers=2 'guardrails_api.app:create_app(".env", "sample-config.py")'
CMD uvicorn --factory 'guardrails_api.app:create_app' --host 0.0.0.0 --port ${PORT} --timeout-keep-alive=90 --workers=4
16 changes: 8 additions & 8 deletions compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,23 @@ services:
- postgres
guardrails-api:
profiles: ["all", "api"]
image: guardrails-api:latest
build:
context: .
dockerfile: Dockerfile
args:
PORT: "8000"
GUARDRAILS_TOKEN: ${GUARDRAILS_TOKEN:-changeme}
ports:
- "8000:8000"
environment:
# APP_ENVIRONMENT: local
# AWS_PROFILE: dev
# AWS_DEFAULT_REGION: us-east-1
# PGPORT: 5432
# PGDATABASE: postgres
# PGHOST: postgres
# PGUSER: ${PGUSER:-postgres}
# PGPASSWORD: ${PGPASSWORD:-changeme}
PGPORT: 5432
PGDATABASE: postgres
PGHOST: postgres
PGUSER: ${PGUSER:-postgres}
PGPASSWORD: ${PGPASSWORD:-changeme}
Comment on lines +45 to +49
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm starting to think we should just have two different Dockerfiles and compose files to run with PG vs config.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed 👌🏼

NLTK_DATA: /opt/nltk_data
# OTEL_PYTHON_TRACER_PROVIDER: sdk_tracer_provider
# OTEL_SERVICE_NAME: guardrails-api
Expand All @@ -68,8 +68,8 @@ services:
# OTEL_EXPORTER_OTLP_METRICS_ENDPOINT: http://otel-collector:4317
# OTEL_EXPORTER_OTLP_LOGS_ENDPOINT: http://otel-collector:4317
# OTEL_PYTHON_LOG_FORMAT: "%(msg)s [span_id=%(span_id)s]"
# depends_on:
# - postgres
depends_on:
- postgres
# - otel-collector
opensearch-node1:
profiles: ["all", "otel", "infra"]
Expand Down
25 changes: 19 additions & 6 deletions guardrails_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@

from starlette.middleware.base import BaseHTTPMiddleware

GR_ENV_FILE = os.environ.get("GR_ENV_FILE", None)
GR_CONFIG_FILE_PATH = os.environ.get("GR_CONFIG_FILE_PATH", None)
PORT = int(os.environ.get("PORT", 8000))

class RequestInfoMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
tracer = trace.get_tracer(__name__)
# Get the current context and attach it to this task
with tracer.start_as_current_span("request_info") as span:
client_ip = request.client.host
client_ip = request.client.host if request.client else None
user_agent = request.headers.get("user-agent", "unknown")
referrer = request.headers.get("referrer", "unknown")
user_id = request.headers.get("x-user-id", "unknown")
Expand All @@ -40,13 +43,15 @@ async def dispatch(self, request: Request, call_next):
context.attach(baggage.set_baggage("organization", organization))
context.attach(baggage.set_baggage("app", app))

span.set_attribute("client.ip", client_ip)
span.set_attribute("http.user_agent", user_agent)
span.set_attribute("http.referrer", referrer)
span.set_attribute("user.id", user_id)
span.set_attribute("organization", organization)
span.set_attribute("app", app)

if client_ip:
span.set_attribute("client.ip", client_ip)

response = await call_next(request)
return response

Expand All @@ -70,9 +75,13 @@ def register_config(config: Optional[str] = None):
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)

return config_file_path

# Support for providing env vars as uvicorn does not support supplying args to create_app
# - Usage: GR_CONFIG_FILE_PATH=config.py GR_ENV_FILE=.env PORT=8080 uvicorn --factory 'guardrails_api.app:create_app' --host 0.0.0.0 --port $PORT --workers 2 --timeout-keep-alive 90
# - Usage: gunicorn -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:$PORT --timeout=90 --workers=2 "guardrails_api.app:create_app(None, None, $PORT)"
Comment on lines +80 to +82
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's rough. We'll have to address this in the OSS's cli as well since it calls this.

def create_app(
env: Optional[str] = None, config: Optional[str] = None, port: Optional[int] = None
env: Optional[str] = GR_ENV_FILE, config: Optional[str] = GR_CONFIG_FILE_PATH, port: Optional[int] = PORT
):
trace_server_start_if_enabled()
# used to print user-facing messages during server startup
Expand All @@ -89,12 +98,12 @@ def create_app(
env_file_path = os.path.abspath(env)
load_dotenv(env_file_path, override=True)

set_port = port or os.environ.get("PORT", 8000)
set_port = port or PORT
host = os.environ.get("HOST", "http://localhost")
self_endpoint = os.environ.get("SELF_ENDPOINT", f"{host}:{set_port}")
os.environ["SELF_ENDPOINT"] = self_endpoint

register_config(config)
resolved_config_file_path = register_config(config)

app = FastAPI(openapi_url="")

Expand Down Expand Up @@ -159,6 +168,10 @@ async def value_error_handler(request: Request, exc: ValueError):
)

console.print("")
console.print("Using the following configuration:")
console.print(f"- Guardrails Log Level: {guardrails_log_level}")
console.print(f"- Self Endpoint: {self_endpoint}")
console.print(f"- Config File Path: {resolved_config_file_path} [Provided: {config}]")
console.print(
Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white")
)
Expand All @@ -170,4 +183,4 @@ async def value_error_handler(request: Request, exc: ValueError):
import uvicorn

app = create_app()
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=PORT)
2 changes: 2 additions & 0 deletions guardrails_api/classes/http_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ def __init__(
context: str = None,
):
self.status = status
self.status_code = status
self.message = message
self.cause = cause
self.detail = f"{message} :: {cause}" if cause is not None else message
self.fields = fields
self.context = context

Expand Down
150 changes: 80 additions & 70 deletions guardrails_api/clients/pg_guard_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
from contextlib import contextmanager
from typing import List, Optional
from guardrails_api_client import Guard as GuardStruct
from guardrails_api.classes.http_error import HttpError
from guardrails_api.clients.guard_client import GuardClient
Expand All @@ -18,48 +19,23 @@ def __init__(self):
self.initialized = True
self.pgClient = PostgresClient()

def get_db(self): # generator for local sessions
@contextmanager
def get_db_context(self):
db = self.pgClient.SessionLocal()
try:
yield db
finally:
db.close()

def get_guard(self, guard_name: str, as_of_date: str = None) -> GuardStruct:
db = next(self.get_db())
latest_guard_item = db.query(GuardItem).filter_by(name=guard_name).first()
audit_item = None
if as_of_date is not None:
audit_item = (
db.query(GuardItemAudit)
.filter_by(name=guard_name)
.filter(GuardItemAudit.replaced_on > as_of_date)
.order_by(GuardItemAudit.replaced_on.asc())
.first()
)
guard_item = audit_item if audit_item is not None else latest_guard_item
if guard_item is None:
raise HttpError(
status=404,
message="NotFound",
cause="A Guard with the name {guard_name} does not exist!".format(
guard_name=guard_name
),
)
return from_guard_item(guard_item)

def get_guard_item(self, guard_name: str) -> GuardItem:
db = next(self.get_db())
return db.query(GuardItem).filter_by(name=guard_name).first()
# These are only internal utilities and do not start db sessions

def get_guards(self) -> List[GuardStruct]:
db = next(self.get_db())
guard_items = db.query(GuardItem).all()
def util_get_guard_item(self, guard_name: str, db) -> GuardItem:
item = db.query(GuardItem).filter_by(name=guard_name).first()
return item

return [from_guard_item(gi) for gi in guard_items]

def create_guard(self, guard: GuardStruct) -> GuardStruct:
db = next(self.get_db())
def util_create_guard(self, guard: GuardStruct, db) -> GuardStruct:
guard_item = GuardItem(
name=guard.name,
railspec=guard.to_dict(),
Expand All @@ -69,48 +45,82 @@ def create_guard(self, guard: GuardStruct) -> GuardStruct:
db.add(guard_item)
db.commit()
return from_guard_item(guard_item)

# Below are used directly by Controllers and start db sessions

def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct:
db = next(self.get_db())
guard_item = self.get_guard_item(guard_name)
if guard_item is None:
raise HttpError(
status=404,
message="NotFound",
cause="A Guard with the name {guard_name} does not exist!".format(
guard_name=guard_name
),
)
# guard_item.num_reasks = guard.num_reasks
guard_item.railspec = guard.to_dict()
guard_item.description = guard.description
db.commit()
return from_guard_item(guard_item)
def get_guard(self, guard_name: str, as_of_date: Optional[str] = None) -> GuardStruct:
with self.get_db_context() as db:
latest_guard_item = db.query(GuardItem).filter_by(name=guard_name).first()
audit_item = None
if as_of_date is not None:
audit_item = (
db.query(GuardItemAudit)
.filter_by(name=guard_name)
.filter(GuardItemAudit.replaced_on > as_of_date)
.order_by(GuardItemAudit.replaced_on.asc())
.first()
)
guard_item = audit_item if audit_item is not None else latest_guard_item
if guard_item is None:
raise HttpError(
status=404,
message="NotFound",
cause="A Guard with the name {guard_name} does not exist!".format(
guard_name=guard_name
),
)
return from_guard_item(guard_item)

def upsert_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct:
db = next(self.get_db())
guard_item = self.get_guard_item(guard_name)
if guard_item is not None:
def get_guards(self) -> List[GuardStruct]:
with self.get_db_context() as db:
guard_items = db.query(GuardItem).all()
return [from_guard_item(gi) for gi in guard_items]

def create_guard(self, guard: GuardStruct) -> GuardStruct:
with self.get_db_context() as db:
return self.util_create_guard(guard, db)

def update_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct:
with self.get_db_context() as db:
guard_item = self.util_get_guard_item(guard_name, db)
if guard_item is None:
raise HttpError(
status=404,
message="NotFound",
cause="A Guard with the name {guard_name} does not exist!".format(
guard_name=guard_name
),
)
# guard_item.num_reasks = guard.num_reasks
guard_item.railspec = guard.to_dict()
guard_item.description = guard.description
# guard_item.num_reasks = guard.num_reasks
db.commit()
return from_guard_item(guard_item)
else:
return self.create_guard(guard)

def upsert_guard(self, guard_name: str, guard: GuardStruct) -> GuardStruct:
with self.get_db_context() as db:
guard_item = self.util_get_guard_item(guard_name, db)
if guard_item is not None:
guard_item.railspec = guard.to_dict()
guard_item.description = guard.description
# guard_item.num_reasks = guard.num_reasks
db.commit()
return from_guard_item(guard_item)
else:
return self.util_create_guard(guard, db)

def delete_guard(self, guard_name: str) -> GuardStruct:
db = next(self.get_db())
guard_item = self.get_guard_item(guard_name)
if guard_item is None:
raise HttpError(
status=404,
message="NotFound",
cause="A Guard with the name {guard_name} does not exist!".format(
guard_name=guard_name
),
)
db.delete(guard_item)
db.commit()
guard = from_guard_item(guard_item)
return guard
with self.get_db_context() as db:
guard_item = self.util_get_guard_item(guard_name, db)
if guard_item is None:
raise HttpError(
status=404,
message="NotFound",
cause="A Guard with the name {guard_name} does not exist!".format(
guard_name=guard_name
),
)
db.delete(guard_item)
db.commit()
guard = from_guard_item(guard_item)
return guard
Loading