Skip to content

Commit a315118

Browse files
committed
internal/llmapp: allow policy checker in overview functions
Add an optional policy checker to the overviews client. When a policy checker is configured, all LLM inputs and outputs will be checked for safety against the configured policy. Not yet used by Gaby or anywhere else. For #70 Change-Id: I8d48048eae9651499ec937a8804ab554baca2316 Reviewed-on: https://go-review.googlesource.com/c/oscar/+/637977 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Hyang-Ah Hana Kim <hyangah@gmail.com>
1 parent 158f50b commit a315118

4 files changed

Lines changed: 165 additions & 8 deletions

File tree

internal/llmapp/check.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package llmapp
6+
7+
import (
8+
"context"
9+
"log/slog"
10+
11+
"golang.org/x/oscar/internal/llm"
12+
"golang.org/x/oscar/internal/storage"
13+
)
14+
15+
// NewWithChecker is like [New], but it configures the Client to use
16+
// the given checker to check the inputs to and outputs of the LLM against
17+
// safety policies.
18+
//
19+
// When any of the Overview functions are called, the prompts and outputs of the LLM
20+
// will be checked for safety violations.
21+
func NewWithChecker(lg *slog.Logger, g llm.ContentGenerator, checker llm.PolicyChecker, db storage.DB) *Client {
22+
return &Client{slog: lg, g: g, checker: checker, db: db}
23+
}
24+
25+
// hasPolicyViolation invokes the policy checker on the given prompts and LLM output and
26+
// logs its results. It reports whether any policy violations were found.
27+
// TODO(tatianabradley): Cache calls to policy checker.
28+
func (c *Client) hasPolicyViolation(ctx context.Context, prompts []llm.Part, output string) bool {
29+
if c.checker == nil {
30+
return false
31+
}
32+
foundViolation := false
33+
for _, p := range prompts {
34+
switch v := p.(type) {
35+
case llm.Text:
36+
if c.logCheck(ctx, string(v), nil) {
37+
foundViolation = true
38+
}
39+
default:
40+
// Other types are not supported for checks yet.
41+
c.slog.Info("llmapp: can't check policy for prompt part (unsupported type)", "prompt part", v)
42+
}
43+
}
44+
if c.logCheck(ctx, output, prompts) {
45+
return true
46+
}
47+
return foundViolation
48+
}
49+
50+
// logCheck invokes the policy checker on the give text (with optional prompts)
51+
// and logs its results.
52+
// It reports whether any policy violations were found.
53+
func (c *Client) logCheck(ctx context.Context, text string, prompts []llm.Part) bool {
54+
prs, err := c.checker.CheckText(ctx, text, prompts...)
55+
if err != nil {
56+
c.slog.Error("llmapp: error checking for policy violations", "err", err)
57+
return false
58+
}
59+
c.slog.Info("llmapp: found policy results", "text", text, "prompts", prompts, "results", toStrings(prs))
60+
if vs := violations(prs); len(vs) > 0 {
61+
c.slog.Warn("llmapp: found policy violations for LLM output", "text", text, "prompts", prompts, "violations", toStrings(vs))
62+
return true
63+
}
64+
return false
65+
}
66+
67+
func toStrings(prs []*llm.PolicyResult) []string {
68+
var ss []string
69+
for _, pr := range prs {
70+
ss = append(ss, pr.String())
71+
}
72+
return ss
73+
}
74+
75+
// violations returns the policies in prs that are in violation.
76+
func violations(prs []*llm.PolicyResult) []*llm.PolicyResult {
77+
var vs []*llm.PolicyResult
78+
for _, pr := range prs {
79+
if pr.IsViolative() {
80+
vs = append(vs, pr)
81+
}
82+
}
83+
return vs
84+
}

internal/llmapp/check_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package llmapp
6+
7+
import (
8+
"context"
9+
"strings"
10+
"testing"
11+
12+
"golang.org/x/oscar/internal/llm"
13+
"golang.org/x/oscar/internal/storage"
14+
"golang.org/x/oscar/internal/testutil"
15+
)
16+
17+
func TestWithChecker(t *testing.T) {
18+
lg := testutil.Slogger(t)
19+
g := llm.EchoContentGenerator()
20+
db := storage.MemDB()
21+
checker := badChecker{}
22+
c := NewWithChecker(lg, g, checker, db)
23+
24+
// With violation.
25+
doc1 := &Doc{URL: "https://example.com", Author: "rsc", Title: "title", Text: "some bad text"}
26+
doc2 := &Doc{Text: "some good text 2"}
27+
r, err := c.Overview(context.Background(), doc1, doc2)
28+
if err != nil {
29+
t.Fatal(err)
30+
}
31+
if !r.HasPolicyViolation {
32+
t.Errorf("c.Overview.HasPolicyViolation = false, want true")
33+
}
34+
35+
// Without violation.
36+
r, err = c.Overview(context.Background(), doc2)
37+
if err != nil {
38+
t.Fatal(err)
39+
}
40+
if r.HasPolicyViolation {
41+
t.Errorf("c.Overview.HasPolicyViolation = true, want false")
42+
}
43+
}
44+
45+
// badChecker is a test implementation of [llm.PolicyChecker] that
46+
// always returns a policy violation for text containing the string "bad",
47+
// and no violations otherwise.
48+
type badChecker struct{}
49+
50+
// no-op
51+
func (badChecker) SetPolicies(_ []*llm.PolicyConfig) {}
52+
53+
// return violation for text containing "bad" and no violation for any other text.
54+
func (badChecker) CheckText(_ context.Context, text string, prompts ...llm.Part) ([]*llm.PolicyResult, error) {
55+
if strings.Contains(text, "bad") {
56+
return []*llm.PolicyResult{
57+
{
58+
PolicyType: llm.PolicyTypeDangerousContent,
59+
ViolationResult: llm.ViolationResultViolative,
60+
},
61+
}, nil
62+
}
63+
return []*llm.PolicyResult{
64+
{
65+
PolicyType: llm.PolicyTypeDangerousContent,
66+
ViolationResult: llm.ViolationResultNonViolative,
67+
},
68+
}, nil
69+
}

internal/llmapp/data.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,6 @@ type Result struct {
2727
Cached bool // whether the response was cached
2828
Schema *llm.Schema // the JSON schema used to generate the result (nil if none)
2929
Prompt []llm.Part // the prompt(s) used to generate the result
30+
// TODO(tatianabradley): Store the specific policy results instead of just a boolean.
31+
HasPolicyViolation bool // whether any policy violations were found for the inputs or outputs of the LLM
3032
}

internal/llmapp/overview.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,17 @@ import (
3333

3434
// Client is a client for accessing the LLM application functionality.
3535
type Client struct {
36-
slog *slog.Logger
37-
g llm.ContentGenerator
38-
db storage.DB // cache for LLM responses
36+
slog *slog.Logger
37+
g llm.ContentGenerator
38+
checker llm.PolicyChecker
39+
db storage.DB // cache for LLM responses
3940
}
4041

4142
// New returns a new client.
4243
// g is the underlying LLM content generator to use, and db is the database
4344
// to use as a cache.
4445
func New(lg *slog.Logger, g llm.ContentGenerator, db storage.DB) *Client {
45-
return &Client{slog: lg, g: g, db: db}
46+
return NewWithChecker(lg, g, nil, db)
4647
}
4748

4849
// Overview returns an LLM-generated overview of the given documents,
@@ -101,10 +102,11 @@ func (c *Client) overview(ctx context.Context, kind docsKind, groups ...*docGrou
101102
return nil, err
102103
}
103104
return &Result{
104-
Response: overview,
105-
Cached: cached,
106-
Schema: schema,
107-
Prompt: prompt,
105+
Response: overview,
106+
Cached: cached,
107+
Schema: schema,
108+
Prompt: prompt,
109+
HasPolicyViolation: c.hasPolicyViolation(ctx, prompt, overview),
108110
}, nil
109111
}
110112

0 commit comments

Comments
 (0)