8
8
from guardrails import Guard
9
9
from guardrails .classes import ValidationOutcome
10
10
from opentelemetry .trace import Span
11
- from src .classes .guard_struct import GuardStruct
12
11
from src .classes .http_error import HttpError
13
- from src .classes .validation_output import ValidationOutput
14
12
from src .clients .memory_guard_client import MemoryGuardClient
15
13
from src .clients .pg_guard_client import PGGuardClient
16
14
from src .clients .postgres_client import postgres_is_enabled
17
15
from src .utils .handle_error import handle_error
18
16
from src .utils .get_llm_callable import get_llm_callable
19
- from src . utils . prep_environment import cleanup_environment , prep_environment
17
+ from guardrails_api_client import Guard as GuardStruct
20
18
21
19
22
20
guards_bp = Blueprint ("guards" , __name__ , url_prefix = "/guards" )
43
41
def guards ():
44
42
if request .method == "GET" :
45
43
guards = guard_client .get_guards ()
46
- if len (guards ) > 0 and (isinstance (guards [0 ], Guard )):
47
- return [g ._to_request () for g in guards ]
48
- return [g .to_response () for g in guards ]
44
+ return [g .to_dict () for g in guards ]
49
45
elif request .method == "POST" :
50
46
if not postgres_is_enabled ():
51
47
raise HttpError (
@@ -54,11 +50,9 @@ def guards():
54
50
"POST /guards is not implemented for in-memory guards." ,
55
51
)
56
52
payload = request .json
57
- guard = GuardStruct .from_request (payload )
53
+ guard = GuardStruct .from_dict (payload )
58
54
new_guard = guard_client .create_guard (guard )
59
- if isinstance (new_guard , Guard ):
60
- return new_guard ._to_request ()
61
- return new_guard .to_response ()
55
+ return new_guard .to_dict ()
62
56
else :
63
57
raise HttpError (
64
58
405 ,
@@ -83,9 +77,7 @@ def guard(guard_name: str):
83
77
guard_name = decoded_guard_name
84
78
),
85
79
)
86
- if isinstance (guard , Guard ):
87
- return guard ._to_request ()
88
- return guard .to_response ()
80
+ return guard .to_dict ()
89
81
elif request .method == "PUT" :
90
82
if not postgres_is_enabled ():
91
83
raise HttpError (
@@ -94,11 +86,9 @@ def guard(guard_name: str):
94
86
"PUT /<guard_name> is not implemented for in-memory guards." ,
95
87
)
96
88
payload = request .json
97
- guard = GuardStruct .from_request (payload )
89
+ guard = GuardStruct .from_dict (payload )
98
90
updated_guard = guard_client .upsert_guard (decoded_guard_name , guard )
99
- if isinstance (updated_guard , Guard ):
100
- return updated_guard ._to_request ()
101
- return updated_guard .to_response ()
91
+ return updated_guard .to_dict ()
102
92
elif request .method == "DELETE" :
103
93
if not postgres_is_enabled ():
104
94
raise HttpError (
@@ -107,9 +97,7 @@ def guard(guard_name: str):
107
97
"DELETE /<guard_name> is not implemented for in-memory guards." ,
108
98
)
109
99
guard = guard_client .delete_guard (decoded_guard_name )
110
- if isinstance (guard , Guard ):
111
- return guard ._to_request ()
112
- return guard .to_response ()
100
+ return guard .to_dict ()
113
101
else :
114
102
raise HttpError (
115
103
405 ,
@@ -123,7 +111,7 @@ def collect_telemetry(
123
111
* ,
124
112
guard : Guard ,
125
113
validate_span : Span ,
126
- validation_output : ValidationOutput ,
114
+ validation_output : ValidationOutcome ,
127
115
prompt_params : Dict [str , Any ],
128
116
result : ValidationOutcome ,
129
117
):
@@ -179,12 +167,9 @@ def validate(guard_name: str):
179
167
)
180
168
decoded_guard_name = unquote_plus (guard_name )
181
169
guard_struct = guard_client .get_guard (decoded_guard_name )
182
- if isinstance (guard_struct , GuardStruct ):
183
- # TODO: is there a way to do this with Guard?
184
- prep_environment (guard_struct )
185
170
186
171
llm_output = payload .pop ("llmOutput" , None )
187
- num_reasks = payload .pop ("numReasks" , guard_struct . num_reasks )
172
+ num_reasks = payload .pop ("numReasks" , None )
188
173
prompt_params = payload .pop ("promptParams" , {})
189
174
llm_api = payload .pop ("llmApi" , None )
190
175
args = payload .pop ("args" , [])
@@ -199,11 +184,10 @@ def validate(guard_name: str):
199
184
# f"validate-{decoded_guard_name}"
200
185
# ) as validate_span:
201
186
# guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer)
202
- guard : Guard = Guard ()
203
- if isinstance (guard_struct , GuardStruct ):
204
- guard : Guard = guard_struct .to_guard (openai_api_key )
205
- elif isinstance (guard_struct , Guard ):
206
- guard = guard_struct
187
+ guard = guard_struct
188
+ if not isinstance (guard_struct , Guard ):
189
+ guard : Guard = Guard .from_dict (guard_struct .to_dict ())
190
+
207
191
# validate_span.set_attribute("guardName", decoded_guard_name)
208
192
if llm_api is not None :
209
193
llm_api = get_llm_callable (llm_api )
@@ -234,22 +218,20 @@ def validate(guard_name: str):
234
218
message = "BadRequest" ,
235
219
cause = "Streaming is not supported for parse calls!" ,
236
220
)
237
-
238
221
result : ValidationOutcome = guard .parse (
239
222
llm_output = llm_output ,
240
223
num_reasks = num_reasks ,
241
224
prompt_params = prompt_params ,
242
225
llm_api = llm_api ,
243
226
# api_key=openai_api_key,
244
- * args ,
245
227
** payload ,
246
228
)
247
229
else :
248
230
if stream :
249
231
250
232
def guard_streamer ():
251
233
guard_stream = guard (
252
- llm_api = llm_api ,
234
+ # llm_api=llm_api,
253
235
prompt_params = prompt_params ,
254
236
num_reasks = num_reasks ,
255
237
stream = stream ,
@@ -260,7 +242,7 @@ def guard_streamer():
260
242
261
243
for result in guard_stream :
262
244
# TODO: Just make this a ValidationOutcome with history
263
- validation_output : ValidationOutput = ValidationOutput (
245
+ validation_output : ValidationOutcome = ValidationOutcome (
264
246
result .validation_passed ,
265
247
result .validated_output ,
266
248
guard .history ,
@@ -278,11 +260,11 @@ def validate_streamer(guard_iter):
278
260
fragment = json .dumps (validation_output .to_response ())
279
261
yield f"{ fragment } \n "
280
262
281
- final_validation_output : ValidationOutput = ValidationOutput (
282
- next_result .validation_passed ,
283
- next_result .validated_output ,
284
- guard .history ,
285
- next_result .raw_llm_output ,
263
+ final_validation_output : ValidationOutcome = ValidationOutcome (
264
+ validation_passed = next_result .validation_passed ,
265
+ validated_output = next_result .validated_output ,
266
+ history = guard .history ,
267
+ raw_llm_output = next_result .raw_llm_output ,
286
268
)
287
269
# I don't know if these are actually making it to OpenSearch
288
270
# because the span may be ended already
@@ -293,7 +275,7 @@ def validate_streamer(guard_iter):
293
275
# prompt_params=prompt_params,
294
276
# result=next_result
295
277
# )
296
- final_output_json = json . dumps ( final_validation_output .to_response () )
278
+ final_output_json = final_validation_output .to_json ( )
297
279
yield f"{ final_output_json } \n "
298
280
299
281
return Response (
@@ -312,12 +294,12 @@ def validate_streamer(guard_iter):
312
294
)
313
295
314
296
# TODO: Just make this a ValidationOutcome with history
315
- validation_output = ValidationOutput (
316
- result .validation_passed ,
317
- result .validated_output ,
318
- guard .history ,
319
- result .raw_llm_output ,
320
- )
297
+ # validation_output = ValidationOutcome (
298
+ # validation_passed = result.validation_passed,
299
+ # validated_output= result.validated_output,
300
+ # history= guard.history,
301
+ # raw_llm_output= result.raw_llm_output,
302
+ # )
321
303
322
304
# collect_telemetry(
323
305
# guard=guard,
@@ -326,6 +308,4 @@ def validate_streamer(guard_iter):
326
308
# prompt_params=prompt_params,
327
309
# result=result
328
310
# )
329
- if isinstance (guard_struct , GuardStruct ):
330
- cleanup_environment (guard_struct )
331
- return validation_output .to_response ()
311
+ return result .to_dict ()
0 commit comments