Add ruff #10

Open · wants to merge 3 commits into base: main
22 changes: 22 additions & 0 deletions .github/workflows/pr_qc.yml
@@ -0,0 +1,22 @@
name: Quality Control

on: [pull_request]  # Trigger the workflow on pull request events

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ".[dev]"
      - name: Run Linter
        run: |
          ruff check .
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,10 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.9.4
    hooks:
      # Run the linter.
      - id: ruff
        args: [ --fix ]
      # Run the formatter.
      - id: ruff-format
19 changes: 19 additions & 0 deletions Makefile
@@ -0,0 +1,19 @@
# Installs production dependencies
install:
	pip install .;

# Installs development dependencies
install-dev:
	pip install ".[dev]";

lint:
	ruff check .
	ruff format .

lint-fix:
	ruff check . --fix
	ruff format .

qa:
	make install-dev
	make lint
44 changes: 16 additions & 28 deletions example.py
@@ -3,47 +3,35 @@
from guardrails_simlab_client import simlab_connect, custom_judge, JudgeResult
from litellm import litellm

CONTROL_PLANE_URL = os.getenv("CONTROL_PLANE_URL", "http://gr-threat-tester-prod-ctrl-svc.gr-threat-tester-prod.priv.local:8080")
CONTROL_PLANE_URL = os.getenv(
"CONTROL_PLANE_URL",
"http://gr-threat-tester-prod-ctrl-svc.gr-threat-tester-prod.priv.local:8080",
)


@custom_judge(
risk_name="example",
enable=True,
control_plane_host=CONTROL_PLANE_URL
)
def example_judge(
user_message: str,
bot_response: str
) -> JudgeResult:
print(f"Running example_judge on user_message: {user_message} and bot_response: {bot_response}")
return JudgeResult(
justification="This is a test",
triggered=False
@custom_judge(risk_name="example", enable=True, control_plane_host=CONTROL_PLANE_URL)
def example_judge(user_message: str, bot_response: str) -> JudgeResult:
print(
f"Running example_judge on user_message: {user_message} and bot_response: {bot_response}"
)
return JudgeResult(justification="This is a test", triggered=False)

@simlab_connect(
enable=True,
control_plane_host=CONTROL_PLANE_URL
)

@simlab_connect(enable=True, control_plane_host=CONTROL_PLANE_URL)
def generate_with_huge_llm(messages) -> str:
print(f"Running generate_with_huge_llm: {messages}")
res = litellm.completion(
model="gpt-4o-mini",
messages=messages
)
res = litellm.completion(model="gpt-4o-mini", messages=messages)
return res.choices[0].message.content


print("Loaded example.py")

if __name__ == '__main__':
if __name__ == "__main__":
print("Running example.py")
prompt = "It was the best of times, it was the worst of times."
out = generate_with_huge_llm([{
"role": "user",
"content": prompt
}])
out = generate_with_huge_llm([{"role": "user", "content": prompt}])
# Nothing below here will happen bc above is blocking
print(out)

judge_out = example_judge(prompt, out)
print(judge_out)
print(judge_out)
11 changes: 4 additions & 7 deletions guardrails_simlab_client/__init__.py
@@ -1,11 +1,8 @@
from guardrails_simlab_client.decorators.llm import tt_webhook_polling_sync
from guardrails_simlab_client.decorators.llm import tt_webhook_polling_sync as simlab_connect
from guardrails_simlab_client.decorators.llm import (
tt_webhook_polling_sync as simlab_connect,
)
from guardrails_simlab_client.decorators.custom_judge import custom_judge
from guardrails_simlab_client.protocols import JudgeResult

__all__ = [
"custom_judge",
"tt_webhook_polling_sync",
"JudgeResult",
"simlab_connect"
]
__all__ = ["custom_judge", "tt_webhook_polling_sync", "JudgeResult", "simlab_connect"]
43 changes: 33 additions & 10 deletions guardrails_simlab_client/decorators/custom_judge.py
@@ -6,10 +6,13 @@
import requests
from guardrails_simlab_client.env import CONTROL_PLANE_URL, _get_api_key, _get_app_id
from guardrails_simlab_client.protocols import JudgeResult
from guardrails_simlab_client.processors.risk_evaluation_processor import RiskEvaluationProcessor
from guardrails_simlab_client.processors.risk_evaluation_processor import (
RiskEvaluationProcessor,
)

LOGGER = getLogger(__name__)


def custom_judge(
*,
risk_name: str,
@@ -32,11 +35,14 @@ def custom_judge(
)

def wrap(
fn: Callable[[str, str], JudgeResult]
fn: Callable[[str, str], JudgeResult],
) -> Callable[[str, str], JudgeResult]:
LOGGER.info(f"===> Wrapping function {fn.__name__}")

def wrapped(*args, **kwargs):
LOGGER.info(f"===> Wrapped function called with args: {args}, kwargs: {kwargs}")
LOGGER.info(
f"===> Wrapped function called with args: {args}, kwargs: {kwargs}"
)
if enable:
LOGGER.info("===> Starting processing")
processor.start_processing(fn)
@@ -52,15 +58,28 @@ def wrapped(*args, **kwargs):
)

if not experiments_response.ok:
LOGGER.info(f"Error fetching experiments: {experiments_response.text}")
raise Exception("Error fetching experiments, task is not healthy")
LOGGER.info(
f"Error fetching experiments: {experiments_response.text}"
)
raise Exception(
"Error fetching experiments, task is not healthy"
)
experiments = experiments_response.json()
LOGGER.info(f"=== Found {len(experiments)} experiments with validation in progress")
LOGGER.info(
f"=== Found {len(experiments)} experiments with validation in progress"
)
# experiments = [{"id": "123"}]
for experiment in experiments:
try:
if not risk_name in experiment.get("source_data",{}).get("evaluation_configuration", {}).keys():
LOGGER.info(f"=== Skipping experiment {experiment['id']} as it does not have risk {risk_name}")
if (
risk_name
not in experiment.get("source_data", {})
.get("evaluation_configuration", {})
.keys()
):
LOGGER.info(
f"=== Skipping experiment {experiment['id']} as it does not have risk {risk_name}"
)
continue
LOGGER.info(
f"=== checking for tests for experiment {experiment['id']}"
@@ -71,8 +90,12 @@ def wrapped(*args, **kwargs):
)

if not tests_response.ok:
LOGGER.info(f"Error fetching tests: {tests_response.text}")
raise Exception("Error fetching tests, task is not healthy")
LOGGER.info(
f"Error fetching tests: {tests_response.text}"
)
raise Exception(
"Error fetching tests, task is not healthy"
)
tests = tests_response.json()

for test in tests:
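For reference, a custom judge wired through this decorator follows the shape below. This is a minimal sketch, not part of the PR: the risk name, host URL, and keyword check are hypothetical placeholders; the decorator arguments and JudgeResult fields match the usage already shown in example.py.

from guardrails_simlab_client import custom_judge, JudgeResult

@custom_judge(risk_name="profanity", enable=True, control_plane_host="http://localhost:8080")
def profanity_judge(user_message: str, bot_response: str) -> JudgeResult:
    # Hypothetical check: flag the response if it contains a blocked term.
    triggered = "blockedword" in bot_response.lower()
    return JudgeResult(
        justification="Matched a blocked term" if triggered else "No blocked terms found",
        triggered=triggered,
    )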
70 changes: 50 additions & 20 deletions guardrails_simlab_client/decorators/llm.py
@@ -1,4 +1,4 @@
from typing import Callable,Optional
from typing import Callable, Optional
from logging import getLogger

import time
@@ -9,15 +9,26 @@

LOGGER = getLogger(__name__)


def tt_webhook_polling_sync(
enable: bool,
control_plane_host: str = CONTROL_PLANE_URL,
max_workers: Optional[int] = None, # Controls max concurrency
application_id: Optional[str] = None,
throttle_time: Optional[float] = None # Time in seconds to pause between each request to the wrapped function
throttle_time: Optional[
float
] = None, # Time in seconds to pause between each request to the wrapped function
) -> Callable:
LOGGER.info(f"===> Initializing TestProcessor with application_id: {application_id}")
processor = TestProcessor(control_plane_host, max_workers, application_id=application_id, throttle_time=throttle_time)
LOGGER.info(
f"===> Initializing TestProcessor with application_id: {application_id}"
)
processor = TestProcessor(
control_plane_host,
max_workers,
application_id=application_id,
throttle_time=throttle_time,
)

def wrap(fn: Callable[[str, ...], str]) -> Callable:
def wrapped(*args, **kwargs):
if enable:
@@ -29,42 +40,53 @@ def wrapped(*args, **kwargs):
LOGGER.info("===> Starting...")
try:
connection_tests_url = f"{control_plane_host}/api/connection-tests?status=pending&appId={_get_app_id(application_id)}"
LOGGER.info(f"Fetching connection tests from {connection_tests_url}")
LOGGER.info(
f"Fetching connection tests from {connection_tests_url}"
)
response = requests.get(
connection_tests_url,
headers={"x-api-key": _get_api_key()},
)

if not response.ok:
LOGGER.info(f"Error fetching connection tests: {response.text}", )
raise Exception("Error fetching connection tests, task is not healthy")
LOGGER.info(
f"Error fetching connection tests: {response.text}",
)
raise Exception(
"Error fetching connection tests, task is not healthy"
)
pending_connection_tests = response.json()
for test in pending_connection_tests:
try:
response = fn([{
"role": "user",
"content": test["prompt"]
}])
response = fn(
[{"role": "user", "content": test["prompt"]}]
)
requests.patch(
f"{control_plane_host}/api/connection-tests/{test['id']}?appId={_get_app_id(application_id)}",
json={
"response": response,
"status": "completed",
"executed_by": _get_app_id(application_id),
"completed_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"completed_at": time.strftime(
"%Y-%m-%dT%H:%M:%SZ", time.gmtime()
),
},
headers={"x-api-key": _get_api_key()},
)
if throttle_time is not None:
time.sleep(throttle_time)
except Exception as e:
LOGGER.info(f"Error processing connection test: {e}")
LOGGER.info(
f"Error processing connection test: {e}"
)
requests.patch(
f"{control_plane_host}/api/connection-tests/{test['id']}?appId={_get_app_id(application_id)}",
json={
"status": "failed",
"executed_by": _get_app_id(application_id),
"failed_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"failed_at": time.strftime(
"%Y-%m-%dT%H:%M:%SZ", time.gmtime()
),
"error": str(e),
},
headers={"x-api-key": _get_api_key()},
@@ -85,10 +107,16 @@ def wrapped(*args, **kwargs):
)

if not experiments_response.ok:
LOGGER.info(f"Error fetching experiments: {experiments_response.text}")
raise Exception("Error fetching experiments, task is not healthy")
LOGGER.info(
f"Error fetching experiments: {experiments_response.text}"
)
raise Exception(
"Error fetching experiments, task is not healthy"
)
experiments = experiments_response.json()
LOGGER.info(f"=== Found {len(experiments)} unevaluated experiments")
LOGGER.info(
f"=== Found {len(experiments)} unevaluated experiments"
)
sleep = True

for experiment in experiments:
@@ -105,7 +133,9 @@ def wrapped(*args, **kwargs):

if not tests_response.ok:
sleep = True
LOGGER.info(f"Error fetching tests: {tests_response.text}")
LOGGER.info(
f"Error fetching tests: {tests_response.text}"
)
continue

tests = tests_response.json()
@@ -133,7 +163,7 @@ def wrapped(*args, **kwargs):
# If it fails for over 1 minute, raise an exception
if experiement_retries > 20:
raise

if sleep:
LOGGER.info("=== Sleeping for 5 seconds")
time.sleep(5)
10 changes: 6 additions & 4 deletions guardrails_simlab_client/env.py
@@ -16,6 +16,7 @@ def _get_app_id(application_id: Optional[str] = None) -> str:
raise ValueError("GUARDRAILS_APP_ID is not set!")
return application_id


def _get_api_key() -> str:
home_filepath = os.path.expanduser("~")
guardrails_rc_filepath = os.path.join(home_filepath, ".guardrailsrc")
@@ -27,11 +28,12 @@ def _get_api_key() -> str:
for line in f:
match = re.match(r"token\s*=\s*(?P<api_key>.+)", line)
if match:
api_key = match.group("api_key").strip()
api_key = match.group("api_key").strip()
break

if not api_key:
raise ValueError("GUARDRAILS_TOKEN environment variable is not set or found in $HOME/.guardrailsrc")
raise ValueError(
"GUARDRAILS_TOKEN environment variable is not set or found in $HOME/.guardrailsrc"
)

return api_key

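For reference, the token lookup that _get_api_key performs against $HOME/.guardrailsrc can be sketched on its own as below. This is an illustrative sketch, not code from the PR: the standalone function name is hypothetical, and the GUARDRAILS_TOKEN environment-variable fallback implied by the error message is not reproduced here.

import os
import re
from typing import Optional

def read_guardrailsrc_token() -> Optional[str]:
    # Mirror the lookup in env.py: read ~/.guardrailsrc and pull the value
    # from the first line matching "token = <value>".
    rc_path = os.path.join(os.path.expanduser("~"), ".guardrailsrc")
    if not os.path.exists(rc_path):
        return None
    with open(rc_path) as f:
        for line in f:
            match = re.match(r"token\s*=\s*(?P<api_key>.+)", line)
            if match:
                return match.group("api_key").strip()
    return None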