Skip to content

Commit 1cd5139

Browse files
authored
Merge pull request #46 from guardrails-ai/nichwch/core-schema-impl
Use new core schemas in guardrails API
2 parents 51e5530 + 797263b commit 1cd5139

23 files changed

+246
-1312
lines changed

.github/workflows/pr_qa.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,18 @@ jobs:
1313
name: PR checks
1414
runs-on: ubuntu-latest
1515
steps:
16-
- uses: actions/checkout@v2
17-
- uses: actions/setup-python@v4
16+
- uses: actions/checkout@v4
17+
- uses: actions/setup-python@v5
18+
with:
19+
python-version: '3.12'
1820
- name: Install Dependencies
1921
run: |
2022
python -m venv ./.venv
2123
source ./.venv/bin/activate
2224
make install-lock;
2325
make install-dev;
2426
25-
curl https://raw.githubusercontent.com/guardrails-ai/guardrails-api-client/main/service-specs/guardrails-service-spec.yml -o ./open-api-spec.yml
26-
npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json ./open-api-spec.yml
27+
cp ./.venv/lib/python3.12/site-packages/guardrails_api_client/openapi-spec.json ./open-api-spec.json
2728
2829
- name: Run Quality Checks
2930
run: |

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ test:
5858

5959
test-cov:
6060
coverage run --source=./src -m pytest ./tests
61-
coverage report --fail-under=50
61+
coverage report --fail-under=45
6262

6363
view-test-cov:
6464
coverage run --source=./src -m pytest ./tests

app.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from flask import Flask
3+
from flask.json.provider import DefaultJSONProvider
34
from flask_cors import CORS
45
from werkzeug.middleware.proxy_fix import ProxyFix
56
from urllib.parse import urlparse
@@ -9,6 +10,14 @@
910
from src.otel import otel_is_disabled, initialize
1011

1112

13+
# TODO: Move this to a separate file
class OverrideJsonProvider(DefaultJSONProvider):
    """Flask JSON provider that additionally serializes ``set`` objects.

    Flask's ``DefaultJSONProvider`` raises ``TypeError`` for sets; this
    subclass converts them to lists and defers everything else to the base
    provider.
    """

    def default(self, o):
        """Return a JSON-serializable form of *o*.

        Sets are converted to lists (order is not guaranteed). All other
        types fall through to ``DefaultJSONProvider.default``, which raises
        ``TypeError`` for unserializable objects.
        """
        if isinstance(o, set):
            return list(o)
        # BUG FIX: the original called ``super().default(self, o)``.
        # ``super().default`` is already bound (or a staticmethod taking a
        # single object), so passing ``self`` made the provider try to
        # serialize the provider instance itself and raised TypeError for
        # every non-set fallback. Only ``o`` must be forwarded.
        return super().default(o)
19+
20+
1221
class ReverseProxied(object):
1322
def __init__(self, app):
1423
self.app = app
@@ -27,6 +36,7 @@ def create_app():
2736
load_dotenv()
2837

2938
app = Flask(__name__)
39+
app.json = OverrideJsonProvider(app)
3040

3141
app.config["APPLICATION_ROOT"] = "/"
3242
app.config["PREFERRED_URL_SCHEME"] = "https"

compose.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ services:
2323
PGADMIN_DEFAULT_EMAIL: "${PGUSER:-postgres}@guardrails.com"
2424
PGADMIN_DEFAULT_PASSWORD: ${PGPASSWORD:-changeme}
2525
PGADMIN_SERVER_JSON_FILE: /var/lib/pgadmin/servers.json
26+
# FIXME: Copy over server.json file and create passfile
2627
volumes:
2728
- ./pgadmin-data:/var/lib/pgadmin
2829
depends_on:

local.sh

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@ export SELF_ENDPOINT=http://localhost:8000
3535
export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
3636
export HF_API_KEY=${HF_TOKEN}
3737

38-
39-
curl https://raw.githubusercontent.com/guardrails-ai/guardrails-api-client/main/service-specs/guardrails-service-spec.yml -o ./open-api-spec.yml
40-
npx @redocly/cli bundle --dereferenced --output ./open-api-spec.json --ext json ./open-api-spec.yml
41-
4238
# For running https locally
4339
# mkdir -p ~/certificates
4440
# if [ ! -f ~/certificates/local.key ]; then

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
flask
22
sqlalchemy
33
lxml
4-
guardrails-ai
4+
guardrails-ai @ git+https://github.com/guardrails-ai/guardrails.git@core-schema-impl
55
# Let this come from guardrails-ai as a transitive dependency.
66
# Pip confuses tag versions with commit ids,
77
# and claims a conflict even though it's the same thing.

sample-config.py

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
'''
1010

1111
from guardrails import Guard
12-
from guardrails.hub import RegexMatch, RestrictToTopic
12+
from guardrails.hub import RegexMatch, ValidChoices, ValidLength #, RestrictToTopic
1313

1414
name_case = Guard(
1515
name='name-case',
1616
description='Checks that a string is in Name Case format.'
1717
).use(
18-
RegexMatch(regex="^[A-Z][a-z\\s]*$")
18+
RegexMatch(regex="^(?:[A-Z][^\s]*\s?)+$")
1919
)
2020

2121
all_caps = Guard(
@@ -25,31 +25,47 @@
2525
RegexMatch(regex="^[A-Z\\s]*$")
2626
)
2727

28-
valid_topics = ["music", "cooking", "camping", "outdoors"]
29-
invalid_topics = ["sports", "work", "ai"]
30-
all_topics = [*valid_topics, *invalid_topics]
31-
32-
def custom_llm (text: str, *args, **kwargs):
33-
return [
34-
{
35-
"name": t,
36-
"present": (t in text),
37-
"confidence": 5
38-
}
39-
for t in all_topics
40-
]
41-
42-
custom_code_guard = Guard(
43-
name='custom',
44-
description='Uses a custom llm for RestrictToTopic'
28+
lower_case = Guard(
29+
name='lower-case',
30+
description='Checks that a string is all lowercase.'
31+
).use(
32+
RegexMatch(regex="^[a-z\\s]*$")
33+
).use(
34+
ValidLength(1, 100)
4535
).use(
46-
RestrictToTopic(
47-
valid_topics=valid_topics,
48-
invalid_topics=invalid_topics,
49-
llm_callable=custom_llm,
50-
disable_classifier=True,
51-
disable_llm=False,
52-
# Pass this so it doesn't load the bart model
53-
classifier_api_endpoint="https://m-1e7af27102f54c3a9eb9cb11aa4715bd-m.default.model-v2.inferless.com/v2/models/RestrictToTopic_1e7af27102f54c3a9eb9cb11aa4715bd/versions/1/infer",
54-
)
55-
)
36+
ValidChoices(["music", "cooking", "camping", "outdoors"])
37+
)
38+
39+
print(lower_case.to_json())
40+
41+
42+
43+
44+
# valid_topics = ["music", "cooking", "camping", "outdoors"]
45+
# invalid_topics = ["sports", "work", "ai"]
46+
# all_topics = [*valid_topics, *invalid_topics]
47+
48+
# def custom_llm (text: str, *args, **kwargs):
49+
# return [
50+
# {
51+
# "name": t,
52+
# "present": (t in text),
53+
# "confidence": 5
54+
# }
55+
# for t in all_topics
56+
# ]
57+
58+
# custom_code_guard = Guard(
59+
# name='custom',
60+
# description='Uses a custom llm for RestrictToTopic'
61+
# ).use(
62+
# RestrictToTopic(
63+
# valid_topics=valid_topics,
64+
# invalid_topics=invalid_topics,
65+
# llm_callable=custom_llm,
66+
# disable_classifier=True,
67+
# disable_llm=False,
68+
# # Pass this so it doesn't load the bart model
69+
# classifier_api_endpoint="https://m-1e7af27102f54c3a9eb9cb11aa4715bd-m.default.model-v2.inferless.com/v2/models/RestrictToTopic_1e7af27102f54c3a9eb9cb11aa4715bd/versions/1/infer",
70+
# )
71+
# )

src/blueprints/guards.py

Lines changed: 29 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,13 @@
88
from guardrails import Guard
99
from guardrails.classes import ValidationOutcome
1010
from opentelemetry.trace import Span
11-
from src.classes.guard_struct import GuardStruct
1211
from src.classes.http_error import HttpError
13-
from src.classes.validation_output import ValidationOutput
1412
from src.clients.memory_guard_client import MemoryGuardClient
1513
from src.clients.pg_guard_client import PGGuardClient
1614
from src.clients.postgres_client import postgres_is_enabled
1715
from src.utils.handle_error import handle_error
1816
from src.utils.get_llm_callable import get_llm_callable
19-
from src.utils.prep_environment import cleanup_environment, prep_environment
17+
from guardrails_api_client import Guard as GuardStruct
2018

2119

2220
guards_bp = Blueprint("guards", __name__, url_prefix="/guards")
@@ -43,9 +41,7 @@
4341
def guards():
4442
if request.method == "GET":
4543
guards = guard_client.get_guards()
46-
if len(guards) > 0 and (isinstance(guards[0], Guard)):
47-
return [g._to_request() for g in guards]
48-
return [g.to_response() for g in guards]
44+
return [g.to_dict() for g in guards]
4945
elif request.method == "POST":
5046
if not postgres_is_enabled():
5147
raise HttpError(
@@ -54,11 +50,9 @@ def guards():
5450
"POST /guards is not implemented for in-memory guards.",
5551
)
5652
payload = request.json
57-
guard = GuardStruct.from_request(payload)
53+
guard = GuardStruct.from_dict(payload)
5854
new_guard = guard_client.create_guard(guard)
59-
if isinstance(new_guard, Guard):
60-
return new_guard._to_request()
61-
return new_guard.to_response()
55+
return new_guard.to_dict()
6256
else:
6357
raise HttpError(
6458
405,
@@ -83,9 +77,7 @@ def guard(guard_name: str):
8377
guard_name=decoded_guard_name
8478
),
8579
)
86-
if isinstance(guard, Guard):
87-
return guard._to_request()
88-
return guard.to_response()
80+
return guard.to_dict()
8981
elif request.method == "PUT":
9082
if not postgres_is_enabled():
9183
raise HttpError(
@@ -94,11 +86,9 @@ def guard(guard_name: str):
9486
"PUT /<guard_name> is not implemented for in-memory guards.",
9587
)
9688
payload = request.json
97-
guard = GuardStruct.from_request(payload)
89+
guard = GuardStruct.from_dict(payload)
9890
updated_guard = guard_client.upsert_guard(decoded_guard_name, guard)
99-
if isinstance(updated_guard, Guard):
100-
return updated_guard._to_request()
101-
return updated_guard.to_response()
91+
return updated_guard.to_dict()
10292
elif request.method == "DELETE":
10393
if not postgres_is_enabled():
10494
raise HttpError(
@@ -107,9 +97,7 @@ def guard(guard_name: str):
10797
"DELETE /<guard_name> is not implemented for in-memory guards.",
10898
)
10999
guard = guard_client.delete_guard(decoded_guard_name)
110-
if isinstance(guard, Guard):
111-
return guard._to_request()
112-
return guard.to_response()
100+
return guard.to_dict()
113101
else:
114102
raise HttpError(
115103
405,
@@ -123,7 +111,7 @@ def collect_telemetry(
123111
*,
124112
guard: Guard,
125113
validate_span: Span,
126-
validation_output: ValidationOutput,
114+
validation_output: ValidationOutcome,
127115
prompt_params: Dict[str, Any],
128116
result: ValidationOutcome,
129117
):
@@ -179,12 +167,9 @@ def validate(guard_name: str):
179167
)
180168
decoded_guard_name = unquote_plus(guard_name)
181169
guard_struct = guard_client.get_guard(decoded_guard_name)
182-
if isinstance(guard_struct, GuardStruct):
183-
# TODO: is there a way to do this with Guard?
184-
prep_environment(guard_struct)
185170

186171
llm_output = payload.pop("llmOutput", None)
187-
num_reasks = payload.pop("numReasks", guard_struct.num_reasks)
172+
num_reasks = payload.pop("numReasks", None)
188173
prompt_params = payload.pop("promptParams", {})
189174
llm_api = payload.pop("llmApi", None)
190175
args = payload.pop("args", [])
@@ -199,11 +184,10 @@ def validate(guard_name: str):
199184
# f"validate-{decoded_guard_name}"
200185
# ) as validate_span:
201186
# guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer)
202-
guard: Guard = Guard()
203-
if isinstance(guard_struct, GuardStruct):
204-
guard: Guard = guard_struct.to_guard(openai_api_key)
205-
elif isinstance(guard_struct, Guard):
206-
guard = guard_struct
187+
guard = guard_struct
188+
if not isinstance(guard_struct, Guard):
189+
guard: Guard = Guard.from_dict(guard_struct.to_dict())
190+
207191
# validate_span.set_attribute("guardName", decoded_guard_name)
208192
if llm_api is not None:
209193
llm_api = get_llm_callable(llm_api)
@@ -234,22 +218,20 @@ def validate(guard_name: str):
234218
message="BadRequest",
235219
cause="Streaming is not supported for parse calls!",
236220
)
237-
238221
result: ValidationOutcome = guard.parse(
239222
llm_output=llm_output,
240223
num_reasks=num_reasks,
241224
prompt_params=prompt_params,
242225
llm_api=llm_api,
243226
# api_key=openai_api_key,
244-
*args,
245227
**payload,
246228
)
247229
else:
248230
if stream:
249231

250232
def guard_streamer():
251233
guard_stream = guard(
252-
llm_api=llm_api,
234+
# llm_api=llm_api,
253235
prompt_params=prompt_params,
254236
num_reasks=num_reasks,
255237
stream=stream,
@@ -260,7 +242,7 @@ def guard_streamer():
260242

261243
for result in guard_stream:
262244
# TODO: Just make this a ValidationOutcome with history
263-
validation_output: ValidationOutput = ValidationOutput(
245+
validation_output: ValidationOutcome = ValidationOutcome(
264246
result.validation_passed,
265247
result.validated_output,
266248
guard.history,
@@ -278,11 +260,11 @@ def validate_streamer(guard_iter):
278260
fragment = json.dumps(validation_output.to_response())
279261
yield f"{fragment}\n"
280262

281-
final_validation_output: ValidationOutput = ValidationOutput(
282-
next_result.validation_passed,
283-
next_result.validated_output,
284-
guard.history,
285-
next_result.raw_llm_output,
263+
final_validation_output: ValidationOutcome = ValidationOutcome(
264+
validation_passed=next_result.validation_passed,
265+
validated_output=next_result.validated_output,
266+
history=guard.history,
267+
raw_llm_output=next_result.raw_llm_output,
286268
)
287269
# I don't know if these are actually making it to OpenSearch
288270
# because the span may be ended already
@@ -293,7 +275,7 @@ def validate_streamer(guard_iter):
293275
# prompt_params=prompt_params,
294276
# result=next_result
295277
# )
296-
final_output_json = json.dumps(final_validation_output.to_response())
278+
final_output_json = final_validation_output.to_json()
297279
yield f"{final_output_json}\n"
298280

299281
return Response(
@@ -312,12 +294,12 @@ def validate_streamer(guard_iter):
312294
)
313295

314296
# TODO: Just make this a ValidationOutcome with history
315-
validation_output = ValidationOutput(
316-
result.validation_passed,
317-
result.validated_output,
318-
guard.history,
319-
result.raw_llm_output,
320-
)
297+
# validation_output = ValidationOutcome(
298+
# validation_passed = result.validation_passed,
299+
# validated_output=result.validated_output,
300+
# history=guard.history,
301+
# raw_llm_output=result.raw_llm_output,
302+
# )
321303

322304
# collect_telemetry(
323305
# guard=guard,
@@ -326,6 +308,4 @@ def validate_streamer(guard_iter):
326308
# prompt_params=prompt_params,
327309
# result=result
328310
# )
329-
if isinstance(guard_struct, GuardStruct):
330-
cleanup_environment(guard_struct)
331-
return validation_output.to_response()
311+
return result.to_dict()

0 commit comments

Comments
 (0)