Skip to content

Commit

Permalink
Merge pull request #46 from guardrails-ai/nichwch/core-schema-impl
Browse files Browse the repository at this point in the history
Use new core schemas in guardrails API
  • Loading branch information
CalebCourier authored Jun 14, 2024
2 parents 51e5530 + 797263b commit 1cd5139
Show file tree
Hide file tree
Showing 23 changed files with 246 additions and 1,312 deletions.
9 changes: 5 additions & 4 deletions .github/workflows/pr_qa.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@ jobs:
name: PR checks
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v4
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install Dependencies
run: |
python -m venv ./.venv
source ./.venv/bin/activate
make install-lock;
make install-dev;
curl https://raw.githubusercontent.com/guardrails-ai/guardrails-api-client/main/service-specs/guardrails-service-spec.yml -o ./open-api-spec.yml
npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json ./open-api-spec.yml
cp ./.venv/lib/python3.12/site-packages/guardrails_api_client/openapi-spec.json ./open-api-spec.json
- name: Run Quality Checks
run: |
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ test:

test-cov:
coverage run --source=./src -m pytest ./tests
coverage report --fail-under=50
coverage report --fail-under=45

view-test-cov:
coverage run --source=./src -m pytest ./tests
Expand Down
10 changes: 10 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from flask import Flask
from flask.json.provider import DefaultJSONProvider
from flask_cors import CORS
from werkzeug.middleware.proxy_fix import ProxyFix
from urllib.parse import urlparse
Expand All @@ -9,6 +10,14 @@
from src.otel import otel_is_disabled, initialize


# TODO: Move this to a separate file
class OverrideJsonProvider(DefaultJSONProvider):
def default(self, o):
if isinstance(o, set):
return list(o)
return super().default(self, o)


class ReverseProxied(object):
def __init__(self, app):
self.app = app
Expand All @@ -27,6 +36,7 @@ def create_app():
load_dotenv()

app = Flask(__name__)
app.json = OverrideJsonProvider(app)

app.config["APPLICATION_ROOT"] = "/"
app.config["PREFERRED_URL_SCHEME"] = "https"
Expand Down
1 change: 1 addition & 0 deletions compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ services:
PGADMIN_DEFAULT_EMAIL: "${PGUSER:-postgres}@guardrails.com"
PGADMIN_DEFAULT_PASSWORD: ${PGPASSWORD:-changeme}
PGADMIN_SERVER_JSON_FILE: /var/lib/pgadmin/servers.json
# FIXME: Copy over server.json file and create passfile
volumes:
- ./pgadmin-data:/var/lib/pgadmin
depends_on:
Expand Down
4 changes: 0 additions & 4 deletions local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ export SELF_ENDPOINT=http://localhost:8000
export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
export HF_API_KEY=${HF_TOKEN}


curl https://raw.githubusercontent.com/guardrails-ai/guardrails-api-client/main/service-specs/guardrails-service-spec.yml -o ./open-api-spec.yml
npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json ./open-api-spec.yml

# For running https locally
# mkdir -p ~/certificates
# if [ ! -f ~/certificates/local.key ]; then
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
flask
sqlalchemy
lxml
guardrails-ai
guardrails-ai @ git+https://github.com/guardrails-ai/guardrails.git@core-schema-impl
# Let this come from guardrails-ai as a transient dependency.
# Pip confuses tag versions with commit ids,
# and claims a conflict even though it's the same thing.
Expand Down
74 changes: 45 additions & 29 deletions sample-config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
'''

from guardrails import Guard
from guardrails.hub import RegexMatch, RestrictToTopic
from guardrails.hub import RegexMatch, ValidChoices, ValidLength #, RestrictToTopic

name_case = Guard(
name='name-case',
description='Checks that a string is in Name Case format.'
).use(
RegexMatch(regex="^[A-Z][a-z\\s]*$")
RegexMatch(regex="^(?:[A-Z][^\s]*\s?)+$")
)

all_caps = Guard(
Expand All @@ -25,31 +25,47 @@
RegexMatch(regex="^[A-Z\\s]*$")
)

valid_topics = ["music", "cooking", "camping", "outdoors"]
invalid_topics = ["sports", "work", "ai"]
all_topics = [*valid_topics, *invalid_topics]

def custom_llm (text: str, *args, **kwargs):
return [
{
"name": t,
"present": (t in text),
"confidence": 5
}
for t in all_topics
]

custom_code_guard = Guard(
name='custom',
description='Uses a custom llm for RestrictToTopic'
lower_case = Guard(
name='lower-case',
description='Checks that a string is all lowercase.'
).use(
RegexMatch(regex="^[a-z\\s]*$")
).use(
ValidLength(1, 100)
).use(
RestrictToTopic(
valid_topics=valid_topics,
invalid_topics=invalid_topics,
llm_callable=custom_llm,
disable_classifier=True,
disable_llm=False,
# Pass this so it doesn't load the bart model
classifier_api_endpoint="https://m-1e7af27102f54c3a9eb9cb11aa4715bd-m.default.model-v2.inferless.com/v2/models/RestrictToTopic_1e7af27102f54c3a9eb9cb11aa4715bd/versions/1/infer",
)
)
ValidChoices(["music", "cooking", "camping", "outdoors"])
)

print(lower_case.to_json())




# valid_topics = ["music", "cooking", "camping", "outdoors"]
# invalid_topics = ["sports", "work", "ai"]
# all_topics = [*valid_topics, *invalid_topics]

# def custom_llm (text: str, *args, **kwargs):
# return [
# {
# "name": t,
# "present": (t in text),
# "confidence": 5
# }
# for t in all_topics
# ]

# custom_code_guard = Guard(
# name='custom',
# description='Uses a custom llm for RestrictToTopic'
# ).use(
# RestrictToTopic(
# valid_topics=valid_topics,
# invalid_topics=invalid_topics,
# llm_callable=custom_llm,
# disable_classifier=True,
# disable_llm=False,
# # Pass this so it doesn't load the bart model
# classifier_api_endpoint="https://m-1e7af27102f54c3a9eb9cb11aa4715bd-m.default.model-v2.inferless.com/v2/models/RestrictToTopic_1e7af27102f54c3a9eb9cb11aa4715bd/versions/1/infer",
# )
# )
78 changes: 29 additions & 49 deletions src/blueprints/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@
from guardrails import Guard
from guardrails.classes import ValidationOutcome
from opentelemetry.trace import Span
from src.classes.guard_struct import GuardStruct
from src.classes.http_error import HttpError
from src.classes.validation_output import ValidationOutput
from src.clients.memory_guard_client import MemoryGuardClient
from src.clients.pg_guard_client import PGGuardClient
from src.clients.postgres_client import postgres_is_enabled
from src.utils.handle_error import handle_error
from src.utils.get_llm_callable import get_llm_callable
from src.utils.prep_environment import cleanup_environment, prep_environment
from guardrails_api_client import Guard as GuardStruct


guards_bp = Blueprint("guards", __name__, url_prefix="/guards")
Expand All @@ -43,9 +41,7 @@
def guards():
if request.method == "GET":
guards = guard_client.get_guards()
if len(guards) > 0 and (isinstance(guards[0], Guard)):
return [g._to_request() for g in guards]
return [g.to_response() for g in guards]
return [g.to_dict() for g in guards]
elif request.method == "POST":
if not postgres_is_enabled():
raise HttpError(
Expand All @@ -54,11 +50,9 @@ def guards():
"POST /guards is not implemented for in-memory guards.",
)
payload = request.json
guard = GuardStruct.from_request(payload)
guard = GuardStruct.from_dict(payload)
new_guard = guard_client.create_guard(guard)
if isinstance(new_guard, Guard):
return new_guard._to_request()
return new_guard.to_response()
return new_guard.to_dict()
else:
raise HttpError(
405,
Expand All @@ -83,9 +77,7 @@ def guard(guard_name: str):
guard_name=decoded_guard_name
),
)
if isinstance(guard, Guard):
return guard._to_request()
return guard.to_response()
return guard.to_dict()
elif request.method == "PUT":
if not postgres_is_enabled():
raise HttpError(
Expand All @@ -94,11 +86,9 @@ def guard(guard_name: str):
"PUT /<guard_name> is not implemented for in-memory guards.",
)
payload = request.json
guard = GuardStruct.from_request(payload)
guard = GuardStruct.from_dict(payload)
updated_guard = guard_client.upsert_guard(decoded_guard_name, guard)
if isinstance(updated_guard, Guard):
return updated_guard._to_request()
return updated_guard.to_response()
return updated_guard.to_dict()
elif request.method == "DELETE":
if not postgres_is_enabled():
raise HttpError(
Expand All @@ -107,9 +97,7 @@ def guard(guard_name: str):
"DELETE /<guard_name> is not implemented for in-memory guards.",
)
guard = guard_client.delete_guard(decoded_guard_name)
if isinstance(guard, Guard):
return guard._to_request()
return guard.to_response()
return guard.to_dict()
else:
raise HttpError(
405,
Expand All @@ -123,7 +111,7 @@ def collect_telemetry(
*,
guard: Guard,
validate_span: Span,
validation_output: ValidationOutput,
validation_output: ValidationOutcome,
prompt_params: Dict[str, Any],
result: ValidationOutcome,
):
Expand Down Expand Up @@ -179,12 +167,9 @@ def validate(guard_name: str):
)
decoded_guard_name = unquote_plus(guard_name)
guard_struct = guard_client.get_guard(decoded_guard_name)
if isinstance(guard_struct, GuardStruct):
# TODO: is there a way to do this with Guard?
prep_environment(guard_struct)

llm_output = payload.pop("llmOutput", None)
num_reasks = payload.pop("numReasks", guard_struct.num_reasks)
num_reasks = payload.pop("numReasks", None)
prompt_params = payload.pop("promptParams", {})
llm_api = payload.pop("llmApi", None)
args = payload.pop("args", [])
Expand All @@ -199,11 +184,10 @@ def validate(guard_name: str):
# f"validate-{decoded_guard_name}"
# ) as validate_span:
# guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer)
guard: Guard = Guard()
if isinstance(guard_struct, GuardStruct):
guard: Guard = guard_struct.to_guard(openai_api_key)
elif isinstance(guard_struct, Guard):
guard = guard_struct
guard = guard_struct
if not isinstance(guard_struct, Guard):
guard: Guard = Guard.from_dict(guard_struct.to_dict())

# validate_span.set_attribute("guardName", decoded_guard_name)
if llm_api is not None:
llm_api = get_llm_callable(llm_api)
Expand Down Expand Up @@ -234,22 +218,20 @@ def validate(guard_name: str):
message="BadRequest",
cause="Streaming is not supported for parse calls!",
)

result: ValidationOutcome = guard.parse(
llm_output=llm_output,
num_reasks=num_reasks,
prompt_params=prompt_params,
llm_api=llm_api,
# api_key=openai_api_key,
*args,
**payload,
)
else:
if stream:

def guard_streamer():
guard_stream = guard(
llm_api=llm_api,
# llm_api=llm_api,
prompt_params=prompt_params,
num_reasks=num_reasks,
stream=stream,
Expand All @@ -260,7 +242,7 @@ def guard_streamer():

for result in guard_stream:
# TODO: Just make this a ValidationOutcome with history
validation_output: ValidationOutput = ValidationOutput(
validation_output: ValidationOutcome = ValidationOutcome(
result.validation_passed,
result.validated_output,
guard.history,
Expand All @@ -278,11 +260,11 @@ def validate_streamer(guard_iter):
fragment = json.dumps(validation_output.to_response())
yield f"{fragment}\n"

final_validation_output: ValidationOutput = ValidationOutput(
next_result.validation_passed,
next_result.validated_output,
guard.history,
next_result.raw_llm_output,
final_validation_output: ValidationOutcome = ValidationOutcome(
validation_passed=next_result.validation_passed,
validated_output=next_result.validated_output,
history=guard.history,
raw_llm_output=next_result.raw_llm_output,
)
# I don't know if these are actually making it to OpenSearch
# because the span may be ended already
Expand All @@ -293,7 +275,7 @@ def validate_streamer(guard_iter):
# prompt_params=prompt_params,
# result=next_result
# )
final_output_json = json.dumps(final_validation_output.to_response())
final_output_json = final_validation_output.to_json()
yield f"{final_output_json}\n"

return Response(
Expand All @@ -312,12 +294,12 @@ def validate_streamer(guard_iter):
)

# TODO: Just make this a ValidationOutcome with history
validation_output = ValidationOutput(
result.validation_passed,
result.validated_output,
guard.history,
result.raw_llm_output,
)
# validation_output = ValidationOutcome(
# validation_passed = result.validation_passed,
# validated_output=result.validated_output,
# history=guard.history,
# raw_llm_output=result.raw_llm_output,
# )

# collect_telemetry(
# guard=guard,
Expand All @@ -326,6 +308,4 @@ def validate_streamer(guard_iter):
# prompt_params=prompt_params,
# result=result
# )
if isinstance(guard_struct, GuardStruct):
cleanup_environment(guard_struct)
return validation_output.to_response()
return result.to_dict()
Loading

0 comments on commit 1cd5139

Please sign in to comment.