diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index ee77f3c..b9b3fbd 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -11,7 +11,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v2 with: - go-version: '1.18' + go-version: '1.19' - name: Run vet run: | go vet . @@ -19,6 +19,5 @@ jobs: uses: golangci/golangci-lint-action@v3 with: version: latest - # # Run testing on the code - # - name: Run testing - # run: cd test && go test -v + - name: Run testing + run: go test -v diff --git a/.golangci.yml b/.golangci.yml index bdf66a8..ee39860 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -188,7 +188,7 @@ linters: - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with f at the end - gosec # Inspects source code for security problems - - lll # Reports long lines + # - lll # Reports long lines - makezero # Finds slice declarations with non-zero initial length # - nakedret # Finds naked returns in functions greater than a specified function length - nestif # Reports deeply nested if statements @@ -205,7 +205,7 @@ linters: - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - stylecheck # Stylecheck is a replacement for golint - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - - testpackage # linter that makes you use a separate _test package + # - testpackage # linter that makes you use a separate _test package - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters diff --git a/README.md b/README.md index b02976d..01dc530 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,14 @@ -# go-gpt3 -[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-gpt3) -[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-gpt3)](https://goreportcard.com/report/github.com/sashabaranov/go-gpt3) +# openai +[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/fabiustech/openai) +[![Go Report Card](https://goreportcard.com/badge/github.com/fabiustech/openai)](https://goreportcard.com/report/github.com/fabiustech/openai) - -[OpenAI GPT-3](https://beta.openai.com/) API wrapper for Go +Zero-dependency Go client for [OpenAI](https://beta.openai.com/) API endpoints. Built upon the great work done [here](https://github.com/sashabaranov/go-gpt3).
Installation: ``` -go get github.com/sashabaranov/go-gpt3 +go get github.com/fabiustech/openai ``` - Example usage: ```go @@ -19,22 +17,23 @@ package main import ( "context" "fmt" - gogpt "github.com/sashabaranov/go-gpt3" + + "github.com/fabiustech/openai" + "github.com/fabiustech/openai/models" ) func main() { - c := gogpt.NewClient("your token") - ctx := context.Background() - - req := gogpt.CompletionRequest{ - Model: gogpt.GPT3Ada, - MaxTokens: 5, + var c = openai.NewClient("your token") + + var resp, err = c.CreateCompletion(context.Background(), &openai.CompletionRequest{ + Model: models.TextDavinci003, + MaxTokens: 100, Prompt: "Lorem ipsum", - } - resp, err := c.CreateCompletion(ctx, req) + }) if err != nil { return } + fmt.Println(resp.Choices[0].Text) } ``` diff --git a/answers.go b/answers.go deleted file mode 100644 index 3a20f2a..0000000 --- a/answers.go +++ /dev/null @@ -1,51 +0,0 @@ -package gogpt - -import ( - "bytes" - "context" - "encoding/json" - "net/http" -) - -type AnswerRequest struct { - Documents []string `json:"documents,omitempty"` - File string `json:"file,omitempty"` - Question string `json:"question"` - SearchModel string `json:"search_model,omitempty"` - Model string `json:"model"` - ExamplesContext string `json:"examples_context"` - Examples [][]string `json:"examples"` - MaxTokens int `json:"max_tokens,omitempty"` - Stop []string `json:"stop,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` -} - -type AnswerResponse struct { - Answers []string `json:"answers"` - Completion string `json:"completion"` - Model string `json:"model"` - Object string `json:"object"` - SearchModel string `json:"search_model"` - SelectedDocuments []struct { - Document int `json:"document"` - Text string `json:"text"` - } `json:"selected_documents"` -} - -// Search — perform a semantic search api call over a list of documents. -func (c *Client) Answers(ctx context.Context, request AnswerRequest) (response AnswerResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) - if err != nil { - return - } - - req, err := http.NewRequest("POST", c.fullURL("/answers"), bytes.NewBuffer(reqBytes)) - if err != nil { - return - } - - req = req.WithContext(ctx) - err = c.sendRequest(req, &response) - return -} diff --git a/api.go b/api.go deleted file mode 100644 index c339afe..0000000 --- a/api.go +++ /dev/null @@ -1,85 +0,0 @@ -package gogpt - -import ( - "encoding/json" - "fmt" - "net/http" -) - -const apiURLv1 = "https://api.openai.com/v1" - -func newTransport() *http.Client { - return &http.Client{} -} - -// Client is OpenAI GPT-3 API client. -type Client struct { - BaseURL string - HTTPClient *http.Client - authToken string - idOrg string -} - -// NewClient creates new OpenAI API client. -func NewClient(authToken string) *Client { - return &Client{ - BaseURL: apiURLv1, - HTTPClient: newTransport(), - authToken: authToken, - idOrg: "", - } -} - -// NewOrgClient creates new OpenAI API client for specified Organization ID. 
-func NewOrgClient(authToken, org string) *Client { - return &Client{ - BaseURL: apiURLv1, - HTTPClient: newTransport(), - authToken: authToken, - idOrg: org, - } -} - -func (c *Client) sendRequest(req *http.Request, v interface{}) error { - req.Header.Set("Accept", "application/json; charset=utf-8") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken)) - - // Check whether Content-Type is already set, Upload Files API requires - // Content-Type == multipart/form-data - contentType := req.Header.Get("Content-Type") - if contentType == "" { - req.Header.Set("Content-Type", "application/json; charset=utf-8") - } - - if len(c.idOrg) > 0 { - req.Header.Set("OpenAI-Organization", c.idOrg) - } - - res, err := c.HTTPClient.Do(req) - if err != nil { - return err - } - - defer res.Body.Close() - - if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { - var errRes ErrorResponse - err = json.NewDecoder(res.Body).Decode(&errRes) - if err != nil || errRes.Error == nil { - return fmt.Errorf("error, status code: %d", res.StatusCode) - } - return fmt.Errorf("error, status code: %d, message: %s", res.StatusCode, errRes.Error.Message) - } - - if v != nil { - if err = json.NewDecoder(res.Body).Decode(&v); err != nil { - return err - } - } - - return nil -} - -func (c *Client) fullURL(suffix string) string { - return fmt.Sprintf("%s%s", c.BaseURL, suffix) -} diff --git a/client.go b/client.go new file mode 100644 index 0000000..b7fdf3b --- /dev/null +++ b/client.go @@ -0,0 +1,203 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "path" + + "github.com/fabiustech/openai/routes" +) + +const ( + scheme = "https" + host = "api.openai.com" + basePath = "v1" +) + +// Client is OpenAI API client. +type Client struct { + token string + orgID *string + + // scheme and host are only used for testing. + // TODO: Figure out a better approach. + scheme, host string +} + +// NewClient creates new OpenAI API client. +func NewClient(token string) *Client { + return &Client{ + token: token, + scheme: scheme, + host: host, + } +} + +// NewClientWithOrg creates new OpenAI API client for specified Organization ID. 
+func NewClientWithOrg(token, org string) *Client { + return &Client{ + token: token, + orgID: &org, + scheme: scheme, + host: host, + } +} + +func (c *Client) newRequest(ctx context.Context, method string, url string, body io.Reader) (*http.Request, error) { + var req, err = http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json; charset=utf-8") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token)) + + if c.orgID != nil { + req.Header.Set("OpenAI-Organization", *c.orgID) + } + + return req, nil +} + +func (c *Client) post(ctx context.Context, path string, payload any) ([]byte, error) { + var b, err = json.Marshal(payload) + if err != nil { + return nil, err + } + + var req *http.Request + req, err = c.newRequest(ctx, "POST", c.reqURL(path), bytes.NewBuffer(b)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json; charset=utf-8") + + var resp *http.Response + resp, err = http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if err = interpretResponse(resp); err != nil { + return nil, err + } + + return io.ReadAll(resp.Body) +} + +func (c *Client) postFile(ctx context.Context, fr *FileRequest) ([]byte, error) { + var b bytes.Buffer + var w = multipart.NewWriter(&b) + + if err := w.WriteField("purpose", fr.Purpose); err != nil { + return nil, err + } + + var fw, err = w.CreateFormFile("file", fr.File.Name()) + if err != nil { + return nil, err + } + + if _, err = io.Copy(fw, fr.File); err != nil { + return nil, err + } + + if err = w.Close(); err != nil { + return nil, err + } + + var req *http.Request + req, err = c.newRequest(ctx, "POST", c.reqURL(routes.Files), &b) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", w.FormDataContentType()) + + var resp *http.Response + resp, err = http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if err = interpretResponse(resp); err != nil { + return nil, err + } + + return io.ReadAll(resp.Body) +} + +func (c *Client) get(ctx context.Context, path string) ([]byte, error) { + var req, err = c.newRequest(ctx, "GET", c.reqURL(path), nil) + if err != nil { + return nil, err + } + + var resp *http.Response + resp, err = http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if err = interpretResponse(resp); err != nil { + return nil, err + } + + return io.ReadAll(resp.Body) +} + +func (c *Client) delete(ctx context.Context, path string) ([]byte, error) { + var req, err = c.newRequest(ctx, "DELETE", c.reqURL(path), nil) + if err != nil { + return nil, err + } + + var resp *http.Response + resp, err = http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if err = interpretResponse(resp); err != nil { + return nil, err + } + + return io.ReadAll(resp.Body) +} + +func (c *Client) reqURL(route string) string { + var u = &url.URL{ + Scheme: c.scheme, + Host: c.host, + Path: path.Join(basePath, route), + } + + return u.String() +} + +func interpretResponse(resp *http.Response) error { + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { + var b, err = io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error, status code: %d", resp.StatusCode) + } + var er = &errorResponse{} + if err = json.Unmarshal(b, er); err != nil || er.Error == nil { + return fmt.Errorf("error, status code: %d, msg: %s",
resp.StatusCode, string(b)) + } + + return er.Error + } + + return nil +} diff --git a/api_test.go b/client_test.go similarity index 66% rename from api_test.go rename to client_test.go index 1e1c5d0..0ac056e 100644 --- a/api_test.go +++ b/client_test.go @@ -1,37 +1,47 @@ -package gogpt_test +package openai import ( "bytes" "context" "encoding/json" "fmt" - "io/ioutil" + "io" "log" "net/http" "net/http/httptest" + "net/url" "os" "strconv" "strings" "testing" "time" - . "github.com/sashabaranov/go-gpt3" + "github.com/fabiustech/openai/images" + "github.com/fabiustech/openai/models" + "github.com/fabiustech/openai/objects" + "github.com/fabiustech/openai/params" ) +/* +This test suite has been ported from the original repo: https://github.com/sashabaranov/go-gpt3. +It is incomplete, and its usefulness is questionable. + +TODO: Cover all endpoints. +*/ + const ( - testAPIToken = "this-is-my-secure-token-do-not-steal!!" + testToken = "this-is-my-secure-token-do-not-steal!!" ) func TestAPI(t *testing.T) { - apiToken := os.Getenv("OPENAI_TOKEN") - if apiToken == "" { + var token, ok = os.LookupEnv("OPENAI_TOKEN") + if !ok { t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") } - var err error - c := NewClient(apiToken) - ctx := context.Background() - _, err = c.ListEngines(ctx) + var c = NewClient(token) + var ctx = context.Background() + var _, err = c.ListEngines(ctx) if err != nil { t.Fatalf("ListEngines error: %v", err) } @@ -41,49 +51,57 @@ func TestAPI(t *testing.T) { t.Fatalf("GetEngine error: %v", err) } - fileRes, err := c.ListFiles(ctx) + var fl *List[*File] + fl, err = c.ListFiles(ctx) if err != nil { t.Fatalf("ListFiles error: %v", err) } - if len(fileRes.Files) > 0 { - _, err = c.GetFile(ctx, fileRes.Files[0].ID) + if len(fl.Data) > 0 { + _, err = c.RetrieveFile(ctx, fl.Data[0].ID) if err != nil { - t.Fatalf("GetFile error: %v", err) + t.Fatalf("RetrieveFile error: %v", err) } - } // else skip + } - embeddingReq := EmbeddingRequest{ + _, err = c.CreateEmbeddings(ctx, &EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: AdaSearchQuery, - } - _, err = c.CreateEmbeddings(ctx, embeddingReq) + Model: models.AdaEmbeddingV2, + }) if err != nil { t.Fatalf("Embedding error: %v", err) } } +func newTestClient(u string) (*Client, error) { + var h, err = url.Parse(u) + if err != nil { + return nil, err + } + + return &Client{ + token: testToken, + host: h.Host, + scheme: "http", + }, nil +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestCompletions(t *testing.T) { - // create the test server - var err error - ts := OpenAITestServer() + var ts = OpenAITestServer() ts.Start() defer ts.Close() - client := NewClient(testAPIToken) - ctx := context.Background() - client.BaseURL = ts.URL + "/v1" + var client, _ = newTestClient(ts.URL) - req := CompletionRequest{ + var _, err = client.CreateCompletion(context.Background(), &CompletionRequest[models.Completion]{ + Prompt: "Lorem ipsum", + Model: models.TextDavinci003, MaxTokens: 5, - Model: "ada", - } - req.Prompt = "Lorem ipsum" - _, err = client.CreateCompletion(ctx, req) + }) if err != nil { t.Fatalf("CreateCompletion error: %v", err) } @@ -91,57 +109,36 @@ } // TestEdits Tests the edits endpoint of the API using the mocked server.
func TestEdits(t *testing.T) { - // create the test server - var err error - ts := OpenAITestServer() + var ts = OpenAITestServer() ts.Start() defer ts.Close() - client := NewClient(testAPIToken) - ctx := context.Background() - client.BaseURL = ts.URL + "/v1" + var client, _ = newTestClient(ts.URL) - // create an edit request - model := "ada" - editReq := EditsRequest{ - Model: &model, + var n = 3 + var resp, err = client.CreateEdit(context.Background(), &EditsRequest{ + Model: models.TextDavinciEdit001, Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + " ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" + " ex ea commodo consequat. Duis aute irure dolor in reprehe", Instruction: "test instruction", - N: 3, - } - response, err := client.Edits(ctx, editReq) + N: n, + }) if err != nil { t.Fatalf("Edits error: %v", err) } - if len(response.Choices) != editReq.N { + if len(resp.Choices) != n { t.Fatalf("edits does not properly return the correct number of choices") } } func TestEmbedding(t *testing.T) { - embeddedModels := []EmbeddingModel{ - AdaSimilarity, - BabbageSimilarity, - CurieSimilarity, - DavinciSimilarity, - AdaSearchDocument, - AdaSearchQuery, - BabbageSearchDocument, - BabbageSearchQuery, - CurieSearchDocument, - CurieSearchQuery, - DavinciSearchDocument, - DavinciSearchQuery, - AdaCodeSearchCode, - AdaCodeSearchText, - BabbageCodeSearchCode, - BabbageCodeSearchText, + embeddedModels := []models.Embedding{ + models.AdaEmbeddingV2, } for _, model := range embeddedModels { - embeddingReq := EmbeddingRequest{ + embeddingReq := &EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -161,16 +158,16 @@ func TestEmbedding(t *testing.T) { } // getEditBody Returns the body of the request to create an edit. 
-func getEditBody(r *http.Request) (EditsRequest, error) { - edit := EditsRequest{} +func getEditBody(r *http.Request) (*EditsRequest, error) { + edit := &EditsRequest{} // read the request body - reqBody, err := ioutil.ReadAll(r.Body) + reqBody, err := io.ReadAll(r.Body) if err != nil { - return EditsRequest{}, err + return nil, err } err = json.Unmarshal(reqBody, &edit) if err != nil { - return EditsRequest{}, err + return nil, err } return edit, nil } @@ -184,15 +181,15 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var editReq EditsRequest + var editReq *EditsRequest editReq, err = getEditBody(r) if err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } // create a response - res := EditsResponse{ - Object: "test-object", + res := &EditsResponse{ + Object: objects.Edit, Created: uint64(time.Now().Unix()), } // edit and calculate token usage @@ -201,12 +198,12 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { completionTokens := int(float32(len(editString))/4) * editReq.N for i := 0; i < editReq.N; i++ { // instruction will be hidden and only seen by OpenAI - res.Choices = append(res.Choices, EditsChoice{ + res.Choices = append(res.Choices, &EditsChoice{ Text: editReq.Input + editString, Index: i, }) } - res.Usage = Usage{ + res.Usage = &Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -224,14 +221,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq CompletionRequest + var completionReq *CompletionRequest[models.Completion] if completionReq, err = getCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := CompletionResponse{ + res := &CompletionResponse[models.Completion]{ ID: strconv.Itoa(int(time.Now().Unix())), - Object: "test-object", + Object: objects.TextCompletion, Created: uint64(time.Now().Unix()), // would be nice to validate Model during testing, but // this may not be possible with how much upkeep @@ -245,14 +242,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if completionReq.Echo { completionStr = completionReq.Prompt + completionStr } - res.Choices = append(res.Choices, CompletionChoice{ + res.Choices = append(res.Choices, &CompletionChoice{ Text: completionStr, Index: i, }) } inputTokens := numTokens(completionReq.Prompt) * completionReq.N completionTokens := completionReq.MaxTokens * completionReq.N - res.Usage = Usage{ + res.Usage = &Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -263,72 +260,79 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { // handleImageEndpoint Handles the images endpoint by the test server. func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // imagess only accepts POST requests + // Images only accepts POST requests. 
if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var imageReq ImageRequest - if imageReq, err = getImageBody(r); err != nil { + var ir, err = getImageBody(r) + if err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := ImageResponse{ + + var resp = &ImageResponse{ Created: uint64(time.Now().Unix()), } - for i := 0; i < imageReq.N; i++ { - imageData := ImageResponseDataInner{} - switch imageReq.ResponseFormat { - case CreateImageResponseFormatURL, "": - imageData.URL = "https://example.com/image.png" - case CreateImageResponseFormatB64JSON: + + // Handle default values. + if ir.N == 0 { + ir.N = 1 + } + + for i := 0; i < ir.N; i++ { + var imageData = &ImageData{} + switch ir.ResponseFormat { + // Invalid is the go default value, and URL is the default API behavior. + case images.FormatURL, images.FormatInvalid: + imageData.URL = params.Optional("https://example.com/image.png") + case images.FormatB64JSON: // This decodes to "{}" in base64. - imageData.B64JSON = "e30K" + imageData.B64JSON = params.Optional("e30K") default: http.Error(w, "invalid response format", http.StatusBadRequest) return } - res.Data = append(res.Data, imageData) + resp.Data = append(resp.Data, imageData) } - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) + + var b, _ = json.Marshal(resp) + _, _ = w.Write(b) } // getCompletionBody Returns the body of the request to create a completion. -func getCompletionBody(r *http.Request) (CompletionRequest, error) { - completion := CompletionRequest{} +func getCompletionBody(r *http.Request) (*CompletionRequest[models.Completion], error) { + var completion = &CompletionRequest[models.Completion]{} // read the request body - reqBody, err := ioutil.ReadAll(r.Body) + reqBody, err := io.ReadAll(r.Body) if err != nil { - return CompletionRequest{}, err + return nil, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return CompletionRequest{}, err + return nil, err } return completion, nil } // getImageBody Returns the body of the request to create a image. -func getImageBody(r *http.Request) (ImageRequest, error) { - image := ImageRequest{} +func getImageBody(r *http.Request) (*CreateImageRequest, error) { + var image = &CreateImageRequest{} // read the request body - reqBody, err := ioutil.ReadAll(r.Body) + var reqBody, err = io.ReadAll(r.Body) if err != nil { - return ImageRequest{}, err + return nil, err } err = json.Unmarshal(reqBody, &image) if err != nil { - return ImageRequest{}, err + return nil, err } + return image, nil } // numTokens Returns the number of GPT-3 encoded tokens in the given text. 
// This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer +// https://beta.openai.com/tokenizer // // TODO: implement an actual tokenizer for GPT-3 and Codex (once available) func numTokens(s string) int { @@ -336,19 +340,12 @@ } func TestImages(t *testing.T) { - // create the test server - var err error - ts := OpenAITestServer() + var ts = OpenAITestServer() ts.Start() defer ts.Close() - client := NewClient(testAPIToken) - ctx := context.Background() - client.BaseURL = ts.URL + "/v1" - - req := ImageRequest{} - req.Prompt = "Lorem ipsum" - _, err = client.CreateImage(ctx, req) + var client, _ = newTestClient(ts.URL) + var _, err = client.CreateImage(context.Background(), &CreateImageRequest{Prompt: "Lorem ipsum"}) if err != nil { t.Fatalf("CreateImage error: %v", err) } @@ -360,7 +357,7 @@ func OpenAITestServer() *httptest.Server { log.Printf("received request at path %q\n", r.URL.Path) // check auth - if r.Header.Get("Authorization") != "Bearer "+testAPIToken { + if r.Header.Get("Authorization") != "Bearer "+testToken { w.WriteHeader(http.StatusUnauthorized) return } @@ -375,7 +372,7 @@ return case "/v1/images/generations": handleImageEndpoint(w, r) - // TODO: implement the other endpoints + // TODO: Implement the other endpoints. default: // the endpoint doesn't exist http.Error(w, "the resource path doesn't exist", http.StatusNotFound) diff --git a/common.go b/common.go deleted file mode 100644 index 9fb0178..0000000 --- a/common.go +++ /dev/null @@ -1,9 +0,0 @@ -// common.go defines common types used throughout the OpenAI API. -package gogpt - -// Usage Represents the total token usage per request to OpenAI. -type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} diff --git a/completion.go b/completion.go deleted file mode 100644 index 97601c3..0000000 --- a/completion.go +++ /dev/null @@ -1,108 +0,0 @@ -package gogpt - -import ( - "bytes" - "context" - "encoding/json" - "net/http" -) - -// GPT3 Defines the models provided by OpenAI to use when generating -// completions from OpenAI. -// GPT3 Models are designed for text-based tasks. For code-specific -// tasks, please refer to the Codex series of models. -const ( - GPT3TextDavinci003 = "text-davinci-003" - GPT3TextDavinci002 = "text-davinci-002" - GPT3TextCurie001 = "text-curie-001" - GPT3TextBabbage001 = "text-babbage-001" - GPT3TextAda001 = "text-ada-001" - GPT3TextDavinci001 = "text-davinci-001" - GPT3DavinciInstructBeta = "davinci-instruct-beta" - GPT3Davinci = "davinci" - GPT3CurieInstructBeta = "curie-instruct-beta" - GPT3Curie = "curie" - GPT3Ada = "ada" - GPT3Babbage = "babbage" -) - -// Codex Defines the models provided by OpenAI. -// These models are designed for code-specific tasks, and use -// a different tokenizer which optimizes for whitespace. -const ( - CodexCodeDavinci002 = "code-davinci-002" - CodexCodeCushman001 = "code-cushman-001" - CodexCodeDavinci001 = "code-davinci-001" -) - -// CompletionRequest represents a request structure for completion API.
-type CompletionRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt,omitempty"` - Suffix string `json:"suffix,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - Echo bool `json:"echo,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - BestOf int `json:"best_of,omitempty"` - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` -} - -// CompletionChoice represents one of possible completions. -type CompletionChoice struct { - Text string `json:"text"` - Index int `json:"index"` - FinishReason string `json:"finish_reason"` - LogProbs LogprobResult `json:"logprobs"` -} - -// LogprobResult represents logprob result of Choice. -type LogprobResult struct { - Tokens []string `json:"tokens"` - TokenLogprobs []float32 `json:"token_logprobs"` - TopLogprobs []map[string]float32 `json:"top_logprobs"` - TextOffset []int `json:"text_offset"` -} - -// CompletionResponse represents a response structure for completion API. -type CompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created uint64 `json:"created"` - Model string `json:"model"` - Choices []CompletionChoice `json:"choices"` - Usage Usage `json:"usage"` -} - -// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well -// as, if requested, the probabilities over each alternative token at each position. -// -// If using a fine-tuned model, simply provide the model's ID in the CompletionRequest object, -// and the server will use the model's parameters to generate the completion. -func (c *Client) CreateCompletion( - ctx context.Context, - request CompletionRequest, -) (response CompletionResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) - if err != nil { - return - } - - urlSuffix := "/completions" - req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) - if err != nil { - return - } - - req = req.WithContext(ctx) - err = c.sendRequest(req, &response) - return -} diff --git a/completions.go b/completions.go new file mode 100644 index 0000000..3b35d4e --- /dev/null +++ b/completions.go @@ -0,0 +1,148 @@ +package openai + +import ( + "context" + "encoding/json" + + "github.com/fabiustech/openai/models" + "github.com/fabiustech/openai/objects" + "github.com/fabiustech/openai/routes" +) + +// CompletionRequest contains all relevant fields for requests to the completions endpoint. +type CompletionRequest[T models.Completion | models.FineTunedModel] struct { + // Model specifies the ID of the model to use. + // See more here: https://beta.openai.com/docs/models/overview + Model T `json:"model"` + // Prompt specifies the prompt(s) to generate completions for, encoded as a string, array of strings, array of + // tokens, or array of token arrays. Note that <|endoftext|> is the document separator that the model sees during + // training, so if a prompt is not specified the model will generate as if from the beginning of a new document. + // Defaults to <|endoftext|>. + Prompt string `json:"prompt,omitempty"` + // Suffix specifies the suffix that comes after a completion of inserted text. 
+ // Defaults to null. + Suffix string `json:"suffix,omitempty"` + // MaxTokens specifies the maximum number of tokens to generate in the completion. The token count of your prompt + // plus max_tokens cannot exceed the model's context length. Most models have a context length of 2048 tokens + // (except for the newest models, which support 4096). + // Defaults to 16. + MaxTokens int `json:"max_tokens,omitempty"` + // Temperature specifies what sampling temperature to use. Higher values mean the model will take more risks. Try + // 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. OpenAI generally + // recommends altering this or top_p but not both. + // More on sampling temperature: https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277 + // Defaults to 1. + Temperature *float32 `json:"temperature,omitempty"` + // TopP specifies an alternative to sampling with temperature, called nucleus sampling, where the model considers + // the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + // probability mass are considered. OpenAI generally recommends altering this or temperature but not both. + // Defaults to 1. + TopP *float32 `json:"top_p,omitempty"` + // N specifies how many completions to generate for each prompt. + // Note: Because this parameter generates many completions, it can quickly consume your token quota. Use carefully + // and ensure that you have reasonable settings for max_tokens and stop. + // Defaults to 1. + N int `json:"n,omitempty"` + // Stream specifies whether to stream back partial progress. If set, tokens will be sent as data-only server-sent + // events as they become available, with the stream terminated by a data: [DONE] message. + // Defaults to false. + Stream bool `json:"stream,omitempty"` + // LogProbs specifies to include the log probabilities on the logprobs most likely tokens, as well as the chosen + // tokens. For example, if logprobs is 5, the API will return a list of the 5 most likely tokens. The API will + // always return the logprob of the sampled token, so there may be up to logprobs+1 elements in the response. + // The maximum value for logprobs is 5. + // Defaults to null. + LogProbs *int `json:"logprobs,omitempty"` + // Echo specifies to echo back the prompt in addition to the completion. + // Defaults to false. + Echo bool `json:"echo,omitempty"` + // Stop specifies up to 4 sequences where the API will stop generating further tokens. The returned text will not + // contain the stop sequence. + Stop []string `json:"stop,omitempty"` + // PresencePenalty can be a number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + // appear in the text so far, increasing the model's likelihood to talk about new topics. + // Defaults to 0. + PresencePenalty float32 `json:"presence_penalty,omitempty"` + // FrequencyPenalty can be a number between -2.0 and 2.0. Positive values penalize new tokens based on their + // existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + // Defaults to 0. + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + // Generates best_of completions server-side and returns the "best" (the one with the highest log probability per + // token). Results cannot be streamed. When used with n, best_of controls the number of candidate completions and n + // specifies how many to return – best_of must be greater than n.
Note: Because this parameter generates many + // completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings + // for max_tokens and stop. + // Defaults to 1. + BestOf *int `json:"best_of,omitempty"` + // LogitBias modifies the likelihood of specified tokens appearing in the completion. Accepts a json object that + // maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. + // Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will + // vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like + // -100 or 100 should result in a ban or exclusive selection of the relevant token. + // As an example, you can pass {"50256": -100} to prevent the <|endoftext|> token from being generated. + // + // You can use this tokenizer tool to convert text to token IDs: + // https://beta.openai.com/tokenizer + // + // Defaults to null. + LogitBias map[string]int `json:"logit_bias,omitempty"` + // User is a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + // See more here: https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids + User string `json:"user,omitempty"` +} + +// CompletionChoice represents one of possible completions. +type CompletionChoice struct { + Text string `json:"text"` + Index int `json:"index"` + FinishReason string `json:"finish_reason"` + LogProbs *LogprobResult `json:"logprobs"` +} + +// LogprobResult represents logprob result of Choice. +type LogprobResult struct { + Tokens []string `json:"tokens"` + TokenLogprobs []float32 `json:"token_logprobs"` + TopLogprobs []map[string]float32 `json:"top_logprobs"` + TextOffset []int `json:"text_offset"` +} + +// CompletionResponse is the response from the completions endpoint. +type CompletionResponse[T models.Completion | models.FineTunedModel] struct { + ID string `json:"id"` + Object objects.Object `json:"object"` + Created uint64 `json:"created"` + Model T `json:"model"` + Choices []*CompletionChoice `json:"choices"` + Usage *Usage `json:"usage"` +} + +// CreateCompletion creates a completion for the provided prompt and parameters. +func (c *Client) CreateCompletion(ctx context.Context, cr *CompletionRequest[models.Completion]) (*CompletionResponse[models.Completion], error) { + var b, err = c.post(ctx, routes.Completions, cr) + if err != nil { + return nil, err + } + + var resp = &CompletionResponse[models.Completion]{} + if err = json.Unmarshal(b, resp); err != nil { + return nil, err + } + + return resp, nil +} + +// CreateFineTunedCompletion creates a completion for the provided prompt and parameters, using a fine-tuned model. +func (c *Client) CreateFineTunedCompletion(ctx context.Context, cr *CompletionRequest[models.FineTunedModel]) (*CompletionResponse[models.FineTunedModel], error) { + var b, err = c.post(ctx, routes.Completions, cr) + if err != nil { + return nil, err + } + + var resp = &CompletionResponse[models.FineTunedModel]{} + if err = json.Unmarshal(b, resp); err != nil { + return nil, err + } + + return resp, nil +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..3fa1a00 --- /dev/null +++ b/doc.go @@ -0,0 +1,3 @@ +// Package openai is a client library for interacting with the OpenAI API. +// It supports all non-deprecated endpoints (as well as the Engines endpoint). 
+package openai diff --git a/edits.go b/edits.go index 8101429..d573867 100644 --- a/edits.go +++ b/edits.go @@ -1,20 +1,36 @@ -package gogpt +package openai import ( - "bytes" "context" "encoding/json" - "net/http" + + "github.com/fabiustech/openai/models" + "github.com/fabiustech/openai/objects" + "github.com/fabiustech/openai/routes" ) -// EditsRequest represents a request structure for Edits API. +// EditsRequest contains all relevant fields for requests to the edits endpoint. type EditsRequest struct { - Model *string `json:"model,omitempty"` - Input string `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - N int `json:"n,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` + Model models.Edit `json:"model"` + // Input is the input text to use as a starting point for the edit. + // Defaults to "". + Input string `json:"input,omitempty"` + // Instruction is the instruction that tells the model how to edit the prompt. + Instruction string `json:"instruction,omitempty"` + // N specifies how many edits to generate for the input and instruction. + // Defaults to 1. + N int `json:"n,omitempty"` + // Temperature specifies what sampling temperature to use. Higher values mean the model will take more risks. + // Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. OpenAI + // generally recommends altering this or top_p but not both. + // More on sampling temperature: https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277 + // Defaults to 1. + Temperature *float32 `json:"temperature,omitempty"` + // TopP specifies an alternative to sampling with temperature, called nucleus sampling, where the model considers + // the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + // probability mass are considered. OpenAI generally recommends altering this or temperature but not both. + // Defaults to 1. + TopP *float32 `json:"top_p,omitempty"` } // EditsChoice represents one of possible edits. @@ -25,26 +41,23 @@ type EditsChoice struct { // EditsResponse represents a response structure for Edits API. type EditsResponse struct { - Object string `json:"object"` - Created uint64 `json:"created"` - Usage Usage `json:"usage"` - Choices []EditsChoice `json:"choices"` + Object objects.Object `json:"object"` // "edit" + Created uint64 `json:"created"` + Usage *Usage `json:"usage"` + Choices []*EditsChoice `json:"choices"` } -// Perform an API call to the Edits endpoint. -func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request)
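A minimal sketch of calling the new CreateEdit method (an editorial illustration, not part of the patch; it assumes a valid API token and mirrors the README example's style):

```go
package main

import (
	"context"
	"fmt"

	"github.com/fabiustech/openai"
	"github.com/fabiustech/openai/models"
)

func main() {
	var c = openai.NewClient("your token")

	// Ask the model to apply an instruction to the input text.
	var resp, err = c.CreateEdit(context.Background(), &openai.EditsRequest{
		Model:       models.TextDavinciEdit001,
		Input:       "What day of the wek is it?",
		Instruction: "Fix the spelling mistakes",
	})
	if err != nil {
		return
	}

	fmt.Println(resp.Choices[0].Text)
}
```

+// CreateEdit creates a new edit for the provided input, instruction, and parameters.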
+func (c *Client) CreateEdit(ctx context.Context, er *EditsRequest) (*EditsResponse, error) { + var b, err = c.post(ctx, routes.Edits, er) if err != nil { - return + return nil, err } - req, err := http.NewRequest("POST", c.fullURL("/edits"), bytes.NewBuffer(reqBytes)) - if err != nil { - return + var resp = &EditsResponse{} + if err = json.Unmarshal(b, resp); err != nil { + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &response) - return + return resp, nil } diff --git a/embeddings.go b/embeddings.go index 52c5223..e7d095f 100644 --- a/embeddings.go +++ b/embeddings.go @@ -1,100 +1,14 @@ -package gogpt +package openai import ( - "bytes" "context" "encoding/json" - "net/http" -) - -// EmbeddingModel enumerates the models which can be used -// to generate Embedding vectors. -type EmbeddingModel int - -// String implements the fmt.Stringer interface. -func (e EmbeddingModel) String() string { - return enumToString[e] -} - -// MarshalText implements the encoding.TextMarshaler interface. -func (e EmbeddingModel) MarshalText() ([]byte, error) { - return []byte(e.String()), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -// On unrecognized value, it sets |e| to Unknown. -func (e *EmbeddingModel) UnmarshalText(b []byte) error { - if val, ok := stringToEnum[(string(b))]; ok { - *e = val - return nil - } - - *e = Unknown - return nil -} - -const ( - Unknown EmbeddingModel = iota - AdaSimilarity - BabbageSimilarity - CurieSimilarity - DavinciSimilarity - AdaSearchDocument - AdaSearchQuery - BabbageSearchDocument - BabbageSearchQuery - CurieSearchDocument - CurieSearchQuery - DavinciSearchDocument - DavinciSearchQuery - AdaCodeSearchCode - AdaCodeSearchText - BabbageCodeSearchCode - BabbageCodeSearchText - AdaEmbeddingV2 + "github.com/fabiustech/openai/models" + "github.com/fabiustech/openai/objects" + "github.com/fabiustech/openai/routes" ) -var enumToString = map[EmbeddingModel]string{ - AdaSimilarity: "text-similarity-ada-001", - BabbageSimilarity: "text-similarity-babbage-001", - CurieSimilarity: "text-similarity-curie-001", - DavinciSimilarity: "text-similarity-davinci-001", - AdaSearchDocument: "text-search-ada-doc-001", - AdaSearchQuery: "text-search-ada-query-001", - BabbageSearchDocument: "text-search-babbage-doc-001", - BabbageSearchQuery: "text-search-babbage-query-001", - CurieSearchDocument: "text-search-curie-doc-001", - CurieSearchQuery: "text-search-curie-query-001", - DavinciSearchDocument: "text-search-davinci-doc-001", - DavinciSearchQuery: "text-search-davinci-query-001", - AdaCodeSearchCode: "code-search-ada-code-001", - AdaCodeSearchText: "code-search-ada-text-001", - BabbageCodeSearchCode: "code-search-babbage-code-001", - BabbageCodeSearchText: "code-search-babbage-text-001", - AdaEmbeddingV2: "text-embedding-ada-002", -} - -var stringToEnum = map[string]EmbeddingModel{ - "text-similarity-ada-001": AdaSimilarity, - "text-similarity-babbage-001": BabbageSimilarity, - "text-similarity-curie-001": CurieSimilarity, - "text-similarity-davinci-001": DavinciSimilarity, - "text-search-ada-doc-001": AdaSearchDocument, - "text-search-ada-query-001": AdaSearchQuery, - "text-search-babbage-doc-001": BabbageSearchDocument, - "text-search-babbage-query-001": BabbageSearchQuery, - "text-search-curie-doc-001": CurieSearchDocument, - "text-search-curie-query-001": CurieSearchQuery, - "text-search-davinci-doc-001": DavinciSearchDocument, - "text-search-davinci-query-001": DavinciSearchQuery, - "code-search-ada-code-001": 
AdaCodeSearchCode, - "code-search-ada-text-001": AdaCodeSearchText, - "code-search-babbage-code-001": BabbageCodeSearchCode, - "code-search-babbage-text-001": BabbageCodeSearchText, - "text-embedding-ada-002": AdaEmbeddingV2, -} - // Embedding is a special format of data representation that can be easily utilized by machine // learning models and algorithms. The embedding is an information dense representation of the // semantic meaning of a piece of text. Each embedding is a vector of floating point numbers, @@ -102,52 +16,40 @@ var stringToEnum = map[string]EmbeddingModel{ // between two inputs in the original format. For example, if two texts are similar, // then their vector representations should also be similar. type Embedding struct { - Object string `json:"object"` - Embedding []float64 `json:"embedding"` - Index int `json:"index"` + Object objects.Object `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` } // EmbeddingResponse is the response from a Create embeddings request. type EmbeddingResponse struct { - Object string `json:"object"` - Data []Embedding `json:"data"` - Model EmbeddingModel `json:"model"` - Usage Usage `json:"usage"` + *List[*Embedding] + Model models.Embedding + Usage *Usage } -// EmbeddingRequest is the input to a Create embeddings request. +// EmbeddingRequest contains all relevant fields for requests to the embeddings endpoint. type EmbeddingRequest struct { - // Input is a slice of strings for which you want to generate an Embedding vector. - // Each input must not exceed 2048 tokens in length. - // OpenAPI suggests replacing newlines (\n) in your input with a single space, as they - // have observed inferior results when newlines are present. - // E.g. - // "The food was delicious and the waiter..." + // Input represents input text to get embeddings for, encoded as strings. To get embeddings for multiple inputs in + // a single request, pass a slice of length > 1. Each input string must not exceed 8192 tokens in length. Input []string `json:"input"` - // ID of the model to use. You can use the List models API to see all of your available models, - // or see our Model overview for descriptions of them. - Model EmbeddingModel `json:"model"` - // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. + // Model is the ID of the model to use. + Model models.Embedding `json:"model"` + // User is a unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. User string `json:"user"` } -// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. -// https://beta.openai.com/docs/api-reference/embeddings/create -func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request)
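A corresponding usage sketch for the new embeddings API (editorial illustration, not part of the patch; assumes a valid API token):

```go
package main

import (
	"context"
	"fmt"

	"github.com/fabiustech/openai"
	"github.com/fabiustech/openai/models"
)

func main() {
	var c = openai.NewClient("your token")

	var resp, err = c.CreateEmbeddings(context.Background(), &openai.EmbeddingRequest{
		Input: []string{"The food was delicious and the waiter..."},
		Model: models.AdaEmbeddingV2,
	})
	if err != nil {
		return
	}

	// Data is promoted from the embedded *List[*Embedding].
	fmt.Println(resp.Data[0].Embedding)
}
```

+// CreateEmbeddings creates an embedding vector representing the input text.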
+func (c *Client) CreateEmbeddings(ctx context.Context, request *EmbeddingRequest) (*EmbeddingResponse, error) { + var b, err = c.post(ctx, routes.Embeddings, request) if err != nil { - return + return nil, err } - urlSuffix := "/embeddings" - req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) - if err != nil { - return + var resp = &EmbeddingResponse{} + if err = json.Unmarshal(b, resp); err != nil { + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &resp) - - return + return resp, nil } diff --git a/engines.go b/engines.go index 3f82e98..1f93cff 100644 --- a/engines.go +++ b/engines.go @@ -1,12 +1,14 @@ -package gogpt +package openai import ( "context" - "fmt" - "net/http" + "encoding/json" + "path" + + "github.com/fabiustech/openai/routes" ) -// Engine struct represents engine from OpenAPI API. +// Engine contains all relevant fields for responses from the engines endpoint. type Engine struct { ID string `json:"id"` Object string `json:"object"` @@ -14,37 +16,39 @@ type Engine struct { Ready bool `json:"ready"` } -// EnginesList is a list of engines. -type EnginesList struct { - Engines []Engine `json:"data"` -} - -// ListEngines Lists the currently available engines, and provides basic +// ListEngines lists the currently available engines, and provides basic // information about each option such as the owner and availability. -func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { - req, err := http.NewRequest("GET", c.fullURL("/engines"), nil) +// +// Deprecated: Please use their replacement, Models, instead. +// https://beta.openai.com/docs/api-reference/models +func (c *Client) ListEngines(ctx context.Context) (*List[*Engine], error) { + var b, err = c.get(ctx, routes.Engines) if err != nil { - return + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &engines) - return + var el = &List[*Engine]{} + if err = json.Unmarshal(b, el); err != nil { + return nil, err + } + + return el, nil } -// GetEngine Retrieves an engine instance, providing basic information about -// the engine such as the owner and availability. -func (c *Client) GetEngine( - ctx context.Context, - engineID string, -) (engine Engine, err error) { - urlSuffix := fmt.Sprintf("/engines/%s", engineID) - req, err := http.NewRequest("GET", c.fullURL(urlSuffix), nil) +// GetEngine retrieves a model instance, providing basic information about it such as the owner and availability. +// +// Deprecated: Please use their replacement, Models, instead. +// https://beta.openai.com/docs/api-reference/models +func (c *Client) GetEngine(ctx context.Context, id string) (*Engine, error) { + var b, err = c.get(ctx, path.Join(routes.Engines, id)) if err != nil { - return + return nil, err + } + + var e = &Engine{} + if err = json.Unmarshal(b, e); err != nil { + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &engine) - return + return e, nil } diff --git a/error.go b/error.go index 4d0a324..86e77e7 100644 --- a/error.go +++ b/error.go @@ -1,10 +1,32 @@ -package gogpt - -type ErrorResponse struct { - Error *struct { - Code *int `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` - } `json:"error,omitempty"` +package openai + +import ( + "fmt" + "net/http" +) + +// errorResponse wraps the returned error.
+type errorResponse struct { + Error *Error `json:"error,omitempty"` +} + +// Error represents an error response from the API. +type Error struct { + Code int `json:"code"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` +} + +// Error implements the error interface. +func (e *Error) Error() string { + return fmt.Sprintf("Code: %v, Message: %s, Type: %s, Param: %v", e.Code, e.Message, e.Type, e.Param) +} + +// Retryable returns true if the error is retryable. +func (e *Error) Retryable() bool { + if e.Code >= http.StatusInternalServerError { + return true + } + return e.Code == http.StatusTooManyRequests } diff --git a/files.go b/files.go index 672f060..6dbf411 100644 --- a/files.go +++ b/files.go @@ -1,156 +1,98 @@ -package gogpt +package openai import ( - "bytes" "context" - "fmt" - "io" - "mime/multipart" - "net/http" - "net/url" + "encoding/json" "os" - "strings" + "path" + + "github.com/fabiustech/openai/objects" + "github.com/fabiustech/openai/routes" ) +// FileRequest contains all relevant data for upload requests to the files endpoint. type FileRequest struct { - FileName string `json:"file"` - FilePath string `json:"-"` - Purpose string `json:"purpose"` -} - -// File struct represents an OpenAPI file. -type File struct { - Bytes int `json:"bytes"` - CreatedAt int `json:"created_at"` - ID string `json:"id"` - FileName string `json:"filename"` - Object string `json:"object"` - Owner string `json:"owner"` - Purpose string `json:"purpose"` + // File is the JSON Lines file to be uploaded. If the purpose is set to "fine-tune", each line is a JSON record + // with "prompt" and "completion" fields representing your training examples: + // https://beta.openai.com/docs/guides/fine-tuning/prepare-training-data. + File *os.File + // Purpose is the intended purpose of the uploaded documents. Use "fine-tune" for Fine-tuning. + // This allows OpenAI to validate the format of the uploaded file. + Purpose string } -// FilesList is a list of files that belong to the user or organization. -type FilesList struct { - Files []File `json:"data"` -} - -// isUrl is a helper function that determines whether the given FilePath -// is a remote URL or a local file path. -func isURL(path string) bool { - _, err := url.ParseRequestURI(path) +// NewFineTuneFileRequest returns a |*FileRequest| with File opened from |path| and Purpose set to "fine-tune". +func NewFineTuneFileRequest(path string) (*FileRequest, error) { + var f, err = os.Open(path) if err != nil { - return false + return nil, err } - u, err := url.Parse(path) - if err != nil || u.Scheme == "" || u.Host == "" { - return false - } - - return true + return &FileRequest{ + File: f, + Purpose: "fine-tune", + }, nil } -// CreateFile uploads a jsonl file to GPT3 -// FilePath can be either a local file path or a URL. -func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) { - var b bytes.Buffer - w := multipart.NewWriter(&b) +// File represents an OpenAI file. +type File struct { + ID string `json:"id"` + Object objects.Object `json:"object"` + Bytes int `json:"bytes"` + CreatedAt int `json:"created_at"` + Filename string `json:"filename"` + Purpose string `json:"purpose"` +} - var fw, pw io.Writer - pw, err = w.CreateFormField("purpose")
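Because interpretResponse returns the decoded *Error directly, callers can branch on Retryable to implement backoff. A minimal sketch (editorial illustration, not part of the patch; the three-attempt policy and exponential backoff are assumptions, not part of the library):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/fabiustech/openai"
	"github.com/fabiustech/openai/models"
)

func main() {
	var c = openai.NewClient("your token")

	var resp *openai.CompletionResponse[models.Completion]
	var err error
	// Retry up to three times on retryable API errors (429s and 5xxs).
	for i := 0; i < 3; i++ {
		resp, err = c.CreateCompletion(context.Background(), &openai.CompletionRequest[models.Completion]{
			Model:  models.TextDavinci003,
			Prompt: "Lorem ipsum",
		})
		var apiErr *openai.Error
		if errors.As(err, &apiErr) && apiErr.Retryable() {
			time.Sleep(time.Second << i) // Exponential backoff.
			continue
		}
		break
	}
	if err != nil {
		return
	}

	fmt.Println(resp.Choices[0].Text)
}
```

+// ListFiles returns a list of files that belong to the user's organization.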
+func (c *Client) ListFiles(ctx context.Context) (*List[*File], error) { + var b, err = c.get(ctx, routes.Files) if err != nil { - return + return nil, err } - _, err = io.Copy(pw, strings.NewReader(request.Purpose)) - if err != nil { - return + var fl = &List[*File]{} + if err = json.Unmarshal(b, fl); err != nil { + return nil, err } - fw, err = w.CreateFormFile("file", request.FileName) - if err != nil { - return - } - - var fileData io.ReadCloser - if isURL(request.FilePath) { - var remoteFile *http.Response - remoteFile, err = http.Get(request.FilePath) - if err != nil { - return - } - - defer remoteFile.Body.Close() - - // Check server response - if remoteFile.StatusCode != http.StatusOK { - err = fmt.Errorf("error, status code: %d, message: failed to fetch file", remoteFile.StatusCode) - return - } - - fileData = remoteFile.Body - } else { - fileData, err = os.Open(request.FilePath) - if err != nil { - return - } - } + return fl, nil +} - _, err = io.Copy(fw, fileData) +// UploadFile uploads a file that contains document(s) to be used across various endpoints/features. Currently, the size +// of all the files uploaded by one organization can be up to 1 GB. +func (c *Client) UploadFile(ctx context.Context, fr *FileRequest) (*File, error) { + var b, err = c.postFile(ctx, fr) if err != nil { - return + return nil, err } - w.Close() - - req, err := http.NewRequest("POST", c.fullURL("/files"), &b) - if err != nil { - return + var f = &File{} + if err = json.Unmarshal(b, f); err != nil { + return nil, err } - req = req.WithContext(ctx) - req.Header.Set("Content-Type", w.FormDataContentType()) - - err = c.sendRequest(req, &file) - - return + return f, nil } -// DeleteFile deletes an existing file. -func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { - req, err := http.NewRequest("DELETE", c.fullURL("/files/"+fileID), nil) - if err != nil { - return - } +// DeleteFile deletes a file. +func (c *Client) DeleteFile(ctx context.Context, id string) error { + var _, err = c.delete(ctx, path.Join(routes.Files, id)) - req = req.WithContext(ctx) - err = c.sendRequest(req, nil) - return + return err } -// ListFiles Lists the currently available files, -// and provides basic information about each file such as the file name and purpose. -func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { - req, err := http.NewRequest("GET", c.fullURL("/files"), nil) +// RetrieveFile returns information about a specific file. +func (c *Client) RetrieveFile(ctx context.Context, id string) (*File, error) { + var b, err = c.get(ctx, path.Join(routes.Files, id)) if err != nil { - return + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &files) - return -} - -// GetFile Retrieves a file instance, providing basic information about the file -// such as the file name and purpose. 
-func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { - urlSuffix := fmt.Sprintf("/files/%s", fileID) - req, err := http.NewRequest("GET", c.fullURL(urlSuffix), nil) - if err != nil { - return + var f = &File{} + if err = json.Unmarshal(b, f); err != nil { + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &file) - return + return f, nil } diff --git a/fine_tunes.go b/fine_tunes.go new file mode 100644 index 0000000..3b9e178 --- /dev/null +++ b/fine_tunes.go @@ -0,0 +1,228 @@ +package openai + +import ( + "context" + "encoding/json" + "path" + + "github.com/fabiustech/openai/models" + "github.com/fabiustech/openai/objects" + "github.com/fabiustech/openai/routes" +) + +// FineTuneRequest contains all relevant fields for requests to the fine-tunes endpoints. +type FineTuneRequest struct { + // TrainingFile specifies the ID of an uploaded file that contains training data. See upload file for how to upload + // a file. + // + // https://beta.openai.com/docs/api-reference/files/upload + // + // Your dataset must be formatted as a JSONL file, where each training example is a JSON object with the keys + // "prompt" and "completion". Additionally, you must upload your file with the purpose fine-tune. See the + // fine-tuning guide for more details: + // + // https://beta.openai.com/docs/guides/fine-tuning/creating-training-data + TrainingFile string `json:"training_file"` + // ValidationFile specifies the ID of an uploaded file that contains validation data. If you provide this file, the + // data is used to generate validation metrics periodically during fine-tuning. These metrics can be viewed in the + // fine-tuning results file. + // + // https://beta.openai.com/docs/guides/fine-tuning/analyzing-your-fine-tuned-model + // + // Your train and validation data should be mutually exclusive. Your dataset must be formatted as a JSONL file, + // where each validation example is a JSON object with the keys "prompt" and "completion". Additionally, you must + // upload your file with the purpose fine-tune. See the fine-tuning guide for more details: + // + // https://beta.openai.com/docs/guides/fine-tuning/creating-training-data + ValidationFile *string `json:"validation_file,omitempty"` + // Model specifies the name of the base model to fine-tune. You can select one of "ada", "babbage", "curie", + // "davinci", or a fine-tuned model created after 2022-04-21. To learn more about these models, see the Models + // documentation. + // Defaults to "curie". + Model *models.FineTune `json:"model,omitempty"` + // NEpochs specifies the number of epochs to train the model for. An epoch refers to one full cycle through + // the training dataset. + // Defaults to 4. + NEpochs *int `json:"n_epochs,omitempty"` + // BatchSize specifies the batch size to use for training. The batch size is the number of training examples used + // to train a single forward and backward pass. By default, the batch size will be dynamically configured to be + // ~0.2% of the number of examples in the training set, capped at 256 - in general, we've found that larger batch + // sizes tend to work better for larger datasets. + // Defaults to null. + BatchSize *int `json:"batch_size,omitempty"` + // LearningRateMultiplier specifies the learning rate multiplier to use for training. The fine-tuning learning rate + // is the original learning rate used for pretraining multiplied by this value. 
By default, the learning rate
+ // multiplier is 0.05, 0.1, or 0.2 depending on the final batch_size (larger learning rates tend to perform better
+ // with larger batch sizes). We recommend experimenting with values in the range 0.02 to 0.2 to see what produces
+ // the best results.
+ // Defaults to null.
+ LearningRateMultiplier *float64 `json:"learning_rate_multiplier,omitempty"`
+ // PromptLossWeight specifies the weight to use for loss on the prompt tokens. This controls how much the model
+ // tries to learn to generate the prompt (as compared to the completion which always has a weight of 1.0), and can
+ // add a stabilizing effect to training when completions are short. If prompts are extremely long (relative to
+ // completions), it may make sense to reduce this weight so as to avoid over-prioritizing learning the prompt.
+ // Defaults to 0.01.
+ PromptLossWeight *float64 `json:"prompt_loss_weight,omitempty"`
+ // ComputeClassificationMetrics calculates classification-specific metrics such as accuracy and F-1 score using the
+ // validation set at the end of every epoch if set to true. These metrics can be viewed in the results file.
+ //
+ // https://beta.openai.com/docs/guides/fine-tuning/analyzing-your-fine-tuned-model
+ //
+ // In order to compute classification metrics, you must provide a ValidationFile. Additionally, you must specify
+ // ClassificationNClasses for multiclass classification or ClassificationPositiveClass for binary classification.
+ ComputeClassificationMetrics bool `json:"compute_classification_metrics,omitempty"`
+ // ClassificationNClasses specifies the number of classes in a classification task. This parameter is required for
+ // multiclass classification.
+ // Defaults to null.
+ ClassificationNClasses *int `json:"classification_n_classes,omitempty"`
+ // ClassificationPositiveClass specifies the positive class in binary classification. This parameter is needed to
+ // generate precision, recall, and F1 metrics when doing binary classification.
+ // Defaults to null.
+ ClassificationPositiveClass *string `json:"classification_positive_class,omitempty"`
+ // ClassificationBetas specifies beta values at which F-beta scores will be calculated, if provided. The
+ // F-beta score is a generalization of the F-1 score. This is only used for binary classification. With a beta of 1
+ // (i.e. the F-1 score), precision and recall are given the same weight. A larger beta score puts more weight on
+ // recall and less on precision. A smaller beta score puts more weight on precision and less on recall.
+ // Defaults to null.
+ ClassificationBetas []float32 `json:"classification_betas,omitempty"`
+ // Suffix specifies a string of up to 40 characters that will be added to your fine-tuned model name. For example,
+ // a suffix of "custom-model-name" would produce a model name like
+ // ada:ft-your-org:custom-model-name-2022-02-15-04-21-04.
+ Suffix string `json:"suffix,omitempty"`
+}
+
+// Event represents an event related to a fine-tune request.
+type Event struct {
+ Object objects.Object `json:"object"`
+ CreatedAt uint64 `json:"created_at"`
+ Level string `json:"level"`
+ Message string `json:"message"`
+}
+
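For a sense of how these request fields compose, here is a minimal, illustrative sketch (the file IDs are placeholders for previously uploaded files; `params.Optional` is the pointer helper added in `params/params.go` later in this diff, and `CreateFineTune` is defined just below):

```go
package main

import (
	"context"
	"fmt"

	"github.com/fabiustech/openai"
	"github.com/fabiustech/openai/models"
	"github.com/fabiustech/openai/params"
)

func main() {
	var c = openai.NewClient("your token")

	// "file-abc123" and "file-def456" are placeholder IDs of previously uploaded files.
	ft, err := c.CreateFineTune(context.Background(), &openai.FineTuneRequest{
		TrainingFile:   "file-abc123",
		ValidationFile: params.Optional("file-def456"),
		Model:          params.Optional(models.Curie),
		NEpochs:        params.Optional(4),
		Suffix:         "custom-model-name",
	})
	if err != nil {
		return
	}

	fmt.Println(ft.ID, ft.Status)
}
```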
+// FineTuneResponse is the response from the fine-tunes endpoints.
+type FineTuneResponse struct {
+ ID string `json:"id"`
+ Object objects.Object `json:"object"`
+ Model models.FineTune `json:"model"`
+ CreatedAt uint64 `json:"created_at"`
+ Events []*Event `json:"events,omitempty"`
+ FineTunedModel *models.FineTunedModel `json:"fine_tuned_model"`
+ Hyperparams struct {
+ BatchSize int `json:"batch_size"`
+ LearningRateMultiplier float64 `json:"learning_rate_multiplier"`
+ NEpochs int `json:"n_epochs"`
+ PromptLossWeight float64 `json:"prompt_loss_weight"`
+ } `json:"hyperparams"`
+ OrganizationID string `json:"organization_id"`
+ ResultFiles []string `json:"result_files"`
+ Status string `json:"status"`
+ ValidationFiles []string `json:"validation_files"`
+ TrainingFiles []struct {
+ ID string `json:"id"`
+ Object objects.Object `json:"object"`
+ Bytes int `json:"bytes"`
+ CreatedAt uint64 `json:"created_at"`
+ Filename string `json:"filename"`
+ Purpose string `json:"purpose"`
+ } `json:"training_files"`
+ UpdatedAt uint64 `json:"updated_at"`
+}
+
+// FineTuneDeletionResponse is the response from the fine-tunes/delete endpoint.
+type FineTuneDeletionResponse struct {
+ ID string `json:"id"`
+ Object objects.Object `json:"object"`
+ Deleted bool `json:"deleted"`
+}
+
+// CreateFineTune creates a job that fine-tunes a specified model from a given dataset. The returned *FineTuneResponse
+// includes details of the enqueued job, including job status and the name of the fine-tuned model once complete.
+func (c *Client) CreateFineTune(ctx context.Context, ftr *FineTuneRequest) (*FineTuneResponse, error) {
+ var b, err = c.post(ctx, routes.FineTunes, ftr)
+ if err != nil {
+ return nil, err
+ }
+
+ var f = &FineTuneResponse{}
+ if err = json.Unmarshal(b, f); err != nil {
+ return nil, err
+ }
+
+ return f, nil
+}
+
+// ListFineTunes lists your organization's fine-tuning jobs.
+func (c *Client) ListFineTunes(ctx context.Context) (*List[*FineTuneResponse], error) {
+ var b, err = c.get(ctx, routes.FineTunes)
+ if err != nil {
+ return nil, err
+ }
+
+ var l = &List[*FineTuneResponse]{}
+ if err = json.Unmarshal(b, l); err != nil {
+ return nil, err
+ }
+
+ return l, nil
+}
+
+// RetrieveFineTune gets info about the fine-tune job.
+func (c *Client) RetrieveFineTune(ctx context.Context, id string) (*FineTuneResponse, error) {
+ var b, err = c.get(ctx, path.Join(routes.FineTunes, id))
+ if err != nil {
+ return nil, err
+ }
+
+ var f = &FineTuneResponse{}
+ if err = json.Unmarshal(b, f); err != nil {
+ return nil, err
+ }
+
+ return f, nil
+}
+
+// CancelFineTune immediately cancels a fine-tune job.
+func (c *Client) CancelFineTune(ctx context.Context, id string) (*FineTuneResponse, error) {
+ var b, err = c.post(ctx, path.Join(routes.FineTunes, id, "cancel"), nil)
+ if err != nil {
+ return nil, err
+ }
+
+ var f = &FineTuneResponse{}
+ if err = json.Unmarshal(b, f); err != nil {
+ return nil, err
+ }
+
+ return f, nil
+}
+
+// ListFineTuneEvents returns fine-grained status updates for a fine-tune job.
+// TODO: Support streaming (in a different method).
+func (c *Client) ListFineTuneEvents(ctx context.Context, id string) (*List[*Event], error) {
+ var b, err = c.get(ctx, path.Join(routes.FineTunes, id, "events"))
+ if err != nil {
+ return nil, err
+ }
+
+ var l = &List[*Event]{}
+ if err = json.Unmarshal(b, l); err != nil {
+ return nil, err
+ }
+
+ return l, nil
+}
+
+// DeleteFineTune deletes a fine-tuned model. You must have the Owner role in your organization.
+func (c *Client) DeleteFineTune(ctx context.Context, id string) (*FineTuneDeletionResponse, error) { + var b, err = c.delete(ctx, path.Join(routes.FineTunes, id)) + if err != nil { + return nil, err + } + + var f = &FineTuneDeletionResponse{} + if err = json.Unmarshal(b, f); err != nil { + return nil, err + } + + return f, nil +} diff --git a/go.mod b/go.mod index 4b6bb42..54ca0af 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/sashabaranov/go-gpt3 +module github.com/fabiustech/openai -go 1.17 +go 1.19 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/image.go b/image.go deleted file mode 100644 index 335e82f..0000000 --- a/image.go +++ /dev/null @@ -1,60 +0,0 @@ -package gogpt - -import ( - "bytes" - "context" - "encoding/json" - "net/http" -) - -// Image sizes defined by the OpenAI API. -const ( - CreateImageSize256x256 = "256x256" - CreateImageSize512x512 = "512x512" - CreateImageSize1024x1024 = "1024x1024" -) - -const ( - CreateImageResponseFormatURL = "url" - CreateImageResponseFormatB64JSON = "b64_json" -) - -// ImageRequest represents the request structure for the image API. -type ImageRequest struct { - Prompt string `json:"prompt,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - User string `json:"user,omitempty"` -} - -// ImageResponse represents a response structure for image API. -type ImageResponse struct { - Created uint64 `json:"created,omitempty"` - Data []ImageResponseDataInner `json:"data,omitempty"` -} - -// ImageResponseData represents a response data structure for image API. -type ImageResponseDataInner struct { - URL string `json:"url,omitempty"` - B64JSON string `json:"b64_json,omitempty"` -} - -// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. -func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) - if err != nil { - return - } - - urlSuffix := "/images/generations" - req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) - if err != nil { - return - } - - req = req.WithContext(ctx) - err = c.sendRequest(req, &response) - return -} diff --git a/images.go b/images.go new file mode 100644 index 0000000..aaa9232 --- /dev/null +++ b/images.go @@ -0,0 +1,133 @@ +package openai + +import ( + "context" + "encoding/json" + + "github.com/fabiustech/openai/images" + "github.com/fabiustech/openai/routes" +) + +// CreateImageRequest contains all relevant fields for requests to the images/generations endpoint. +type CreateImageRequest struct { + // Prompt is a text description of the desired image(s). The maximum length is 1000 characters. + Prompt string `json:"prompt"` + // N specifies the number of images to generate. Must be between 1 and 10. + // Defaults to 1. + N int `json:"n,omitempty"` + // Size specifies the size of the generated images. Must be one of images.Size256x256, images.Size512x512, or + // images.Size1024x1024. + // Defaults to images.Size1024x1024. + Size images.Size `json:"size,omitempty"` + // ResponseFormat specifies the format in which the generated images are returned. Must be one of images.FormatURL + // or images.FormatB64JSON. + // Defaults to images.FormatURL. 
+ ResponseFormat images.Format `json:"response_format,omitempty"`
+ // User specifies a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse:
+ // https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids.
+ User string `json:"user,omitempty"`
+}
+
+// EditImageRequest contains all relevant fields for requests to the images/edits endpoint.
+type EditImageRequest struct {
+ // Image is the image to edit. Must be a valid PNG file, less than 4MB, and square. If Mask is not provided, Image
+ // must have transparency, which will be used as the mask.
+ Image string `json:"image"`
+ // Mask is an additional image whose fully transparent areas (e.g. where alpha is zero) indicate where Image should
+ // be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as Image.
+ Mask string `json:"mask,omitempty"`
+ // Prompt is a text description of the desired image(s). The maximum length is 1000 characters.
+ Prompt string `json:"prompt"`
+ // N specifies the number of images to generate. Must be between 1 and 10.
+ // Defaults to 1.
+ N int `json:"n,omitempty"`
+ // Size specifies the size of the generated images. Must be one of images.Size256x256, images.Size512x512, or
+ // images.Size1024x1024.
+ // Defaults to images.Size1024x1024.
+ Size images.Size `json:"size,omitempty"`
+ // ResponseFormat specifies the format in which the generated images are returned. Must be one of images.FormatURL
+ // or images.FormatB64JSON.
+ // Defaults to images.FormatURL.
+ ResponseFormat images.Format `json:"response_format,omitempty"`
+ // User specifies a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse:
+ // https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids.
+ User string `json:"user,omitempty"`
+}
+
+// VariationImageRequest contains all relevant fields for requests to the images/variations endpoint.
+type VariationImageRequest struct {
+ // Image is the image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square.
+ Image string `json:"image"`
+ // N specifies the number of images to generate. Must be between 1 and 10.
+ // Defaults to 1.
+ N int `json:"n,omitempty"`
+ // Size specifies the size of the generated images. Must be one of images.Size256x256, images.Size512x512, or
+ // images.Size1024x1024.
+ // Defaults to images.Size1024x1024.
+ Size images.Size `json:"size,omitempty"`
+ // ResponseFormat specifies the format in which the generated images are returned. Must be one of images.FormatURL
+ // or images.FormatB64JSON.
+ // Defaults to images.FormatURL.
+ ResponseFormat images.Format `json:"response_format,omitempty"`
+ // User specifies a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse:
+ // https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids.
+ User string `json:"user,omitempty"`
+}
+
+// ImageResponse represents a response structure for the image API.
+type ImageResponse struct {
+ Created uint64 `json:"created,omitempty"`
+ Data []*ImageData `json:"data,omitempty"`
+}
+
+// ImageData represents a response data structure for the image API.
+// Only one field will be non-nil.
+type ImageData struct {
+ URL *string `json:"url,omitempty"`
+ B64JSON *string `json:"b64_json,omitempty"`
+}
+
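As an illustrative sketch, generating an image with these request types looks like the following (the prompt is arbitrary; `CreateImage` is defined just below):

```go
package main

import (
	"context"
	"fmt"

	"github.com/fabiustech/openai"
	"github.com/fabiustech/openai/images"
)

func main() {
	var c = openai.NewClient("your token")

	resp, err := c.CreateImage(context.Background(), &openai.CreateImageRequest{
		Prompt:         "a gopher painting a self-portrait",
		N:              1,
		Size:           images.Size256x256,
		ResponseFormat: images.FormatURL,
	})
	if err != nil {
		return
	}

	// With images.FormatURL, only the URL field of each ImageData is set.
	if resp.Data[0].URL != nil {
		fmt.Println(*resp.Data[0].URL)
	}
}
```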
+// CreateImage creates an image (or images) given a prompt.
+func (c *Client) CreateImage(ctx context.Context, ir *CreateImageRequest) (*ImageResponse, error) {
+ var b, err = c.post(ctx, routes.ImageGenerations, ir)
+ if err != nil {
+ return nil, err
+ }
+
+ var resp = &ImageResponse{}
+ if err = json.Unmarshal(b, resp); err != nil {
+ return nil, err
+ }
+
+ return resp, nil
+}
+
+// EditImage creates an edited or extended image (or images) given an original image and a prompt.
+func (c *Client) EditImage(ctx context.Context, eir *EditImageRequest) (*ImageResponse, error) {
+ var b, err = c.post(ctx, routes.ImageEdits, eir)
+ if err != nil {
+ return nil, err
+ }
+
+ var resp = &ImageResponse{}
+ if err = json.Unmarshal(b, resp); err != nil {
+ return nil, err
+ }
+
+ return resp, nil
+}
+
+// ImageVariation creates a variation (or variations) of a given image.
+func (c *Client) ImageVariation(ctx context.Context, vir *VariationImageRequest) (*ImageResponse, error) {
+ var b, err = c.post(ctx, routes.ImageVariations, vir)
+ if err != nil {
+ return nil, err
+ }
+
+ var resp = &ImageResponse{}
+ if err = json.Unmarshal(b, resp); err != nil {
+ return nil, err
+ }
+
+ return resp, nil
+}
diff --git a/images/formats.go b/images/formats.go
new file mode 100644
index 0000000..e590baa
--- /dev/null
+++ b/images/formats.go
@@ -0,0 +1,50 @@
+// Package images contains the enum values which represent the various
+// image formats and sizes returned by the OpenAI image endpoints.
+package images
+
+// Format represents the enum values for the formats in which
+// generated images are returned.
+type Format int
+
+const (
+ // FormatInvalid represents an invalid Format option.
+ FormatInvalid Format = iota
+ // FormatURL specifies that the API will return a url to the generated image.
+ // URLs will expire after an hour.
+ FormatURL
+ // FormatB64JSON specifies that the API will return the image as Base64 data.
+ FormatB64JSON
+)
+
+// String implements the fmt.Stringer interface.
+func (f Format) String() string {
+ return formatToString[f]
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+func (f Format) MarshalText() ([]byte, error) {
+ return []byte(f.String()), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// On unrecognized value, it sets |f| to FormatInvalid.
+func (f *Format) UnmarshalText(b []byte) error {
+ if val, ok := stringToFormat[(string(b))]; ok {
+ *f = val
+ return nil
+ }
+
+ *f = FormatInvalid
+
+ return nil
+}
+
+var formatToString = map[Format]string{
+ FormatURL: "url",
+ FormatB64JSON: "b64_json",
+}
+
+var stringToFormat = map[string]Format{
+ "url": FormatURL,
+ "b64_json": FormatB64JSON,
+}
diff --git a/images/sizes.go b/images/sizes.go
new file mode 100644
index 0000000..0e43ae7
--- /dev/null
+++ b/images/sizes.go
@@ -0,0 +1,54 @@
+package images
+
+// Size represents the enum values for the image sizes that
+// you can generate. Smaller sizes are faster to generate.
+type Size int
+
+const (
+ // SizeInvalid represents an invalid Size option.
+ SizeInvalid Size = iota
+ // Size256x256 specifies that the API will return an image that is
+ // 256x256 pixels.
+ Size256x256
+ // Size512x512 specifies that the API will return an image that is
+ // 512x512 pixels.
+ Size512x512
+ // Size1024x1024 specifies that the API will return an image that is
+ // 1024x1024 pixels.
+ Size1024x1024
+)
+
+// String implements the fmt.Stringer interface.
+func (s Size) String() string {
+ return imageToString[s]
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+func (s Size) MarshalText() ([]byte, error) {
+ return []byte(s.String()), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// On unrecognized value, it sets |s| to SizeInvalid.
+func (s *Size) UnmarshalText(b []byte) error {
+ if val, ok := stringToImage[(string(b))]; ok {
+ *s = val
+ return nil
+ }
+
+ *s = SizeInvalid
+
+ return nil
+}
+
+var imageToString = map[Size]string{
+ Size256x256: "256x256",
+ Size512x512: "512x512",
+ Size1024x1024: "1024x1024",
+}
+
+var stringToImage = map[string]Size{
+ "256x256": Size256x256,
+ "512x512": Size512x512,
+ "1024x1024": Size1024x1024,
+}
diff --git a/list.go b/list.go
new file mode 100644
index 0000000..5184927
--- /dev/null
+++ b/list.go
@@ -0,0 +1,13 @@
+package openai
+
+import (
+ "github.com/fabiustech/openai/objects"
+)
+
+// List represents the generic list of objects returned from many GET endpoints.
+type List[T any] struct {
+ // Object specifies the object type (e.g. Model).
+ Object objects.Object `json:"object"`
+ // Data contains the list of objects.
+ Data []T `json:"data"`
+}
diff --git a/models/completions.go b/models/completions.go
new file mode 100644
index 0000000..23f6d35
--- /dev/null
+++ b/models/completions.go
@@ -0,0 +1,133 @@
+// Package models contains the enum values which represent the various
+// models used by all OpenAI endpoints.
+package models
+
+// Completion represents all models available for use with the Completions endpoint.
+type Completion int
+
+const (
+ // UnknownCompletion represents an invalid Completion model.
+ UnknownCompletion Completion = iota
+ // TextDavinci003 is the most capable GPT-3 model. Can do any task the other models can do,
+ // often with higher quality, longer output and better instruction-following. Also supports
+ // inserting completions within text.
+ //
+ // Supports up to 4,000 tokens. Training data up to Jun 2021.
+ TextDavinci003
+ // TextDavinci002 is an older version of the most capable GPT-3 model. Can do any task the
+ // other models can do, often with higher quality, longer output and better
+ // instruction-following. Also supports inserting completions within text.
+ //
+ // Supports up to 4,000 tokens.
+ //
+ // Deprecated: Use TextDavinci003 instead.
+ TextDavinci002
+ // TextCurie001 is very capable, but faster and lower cost than Davinci.
+ //
+ // Supports up to 2,048 tokens. Training data up to Oct 2019.
+ TextCurie001
+ // TextBabbage001 is capable of straightforward tasks, very fast, and lower cost.
+ //
+ // Supports up to 2,048 tokens. Training data up to Oct 2019.
+ TextBabbage001
+ // TextAda001 is capable of very simple tasks, usually the fastest model in the
+ // GPT-3 series, and lowest cost.
+ //
+ // Supports up to 2,048 tokens. Training data up to Oct 2019.
+ TextAda001
+ // TextDavinci001 ... (?).
+ TextDavinci001
+
+ // DavinciInstructBeta is the most capable model in the InstructGPT series.
+ // It is much better at following user intentions than GPT-3 while also being
+ // more truthful and less toxic. InstructGPT is better than GPT-3 at following
+ // English instructions.
+ DavinciInstructBeta
+ // CurieInstructBeta is very capable, but faster and lower cost than Davinci.
+ // It is much better at following user intentions than GPT-3 while also being
+ // more truthful and less toxic. InstructGPT is better than GPT-3 at following
+ // English instructions.
+ CurieInstructBeta
+
+ // CodeDavinci002 is the most capable Codex model. Particularly good at
+ // translating natural language to code. 
In addition to completing code,
+ // also supports inserting completions within code.
+ //
+ // Supports up to 8,000 tokens. Training data up to Jun 2021.
+ CodeDavinci002
+ // CodeCushman001 is almost as capable as Davinci Codex, but slightly faster.
+ // This speed advantage may make it preferable for real-time applications.
+ //
+ // Supports up to 2,048 tokens.
+ CodeCushman001
+ // CodeDavinci001 is an older version of the most capable Codex model.
+ // Particularly good at translating natural language to code. In addition
+ // to completing code, also supports inserting completions within code.
+ //
+ // Deprecated: Use CodeDavinci002 instead.
+ CodeDavinci001
+
+ // TextDavinciInsert002 was a beta model released for insertion.
+ //
+ // Deprecated: Insertion should be done via the text models.
+ TextDavinciInsert002
+ // TextDavinciInsert001 was a beta model released for insertion.
+ //
+ // Deprecated: Insertion should be done via the text models.
+ TextDavinciInsert001
+)
+
+// String implements the fmt.Stringer interface.
+func (c Completion) String() string {
+ return completionToString[c]
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+func (c Completion) MarshalText() ([]byte, error) {
+ return []byte(c.String()), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// On unrecognized value, it sets |c| to UnknownCompletion.
+func (c *Completion) UnmarshalText(b []byte) error {
+ if val, ok := stringToCompletion[(string(b))]; ok {
+ *c = val
+ return nil
+ }
+
+ *c = UnknownCompletion
+
+ return nil
+}
+
+var completionToString = map[Completion]string{
+ TextDavinci003: "text-davinci-003",
+ TextDavinci002: "text-davinci-002",
+ TextCurie001: "text-curie-001",
+ TextBabbage001: "text-babbage-001",
+ TextAda001: "text-ada-001",
+ TextDavinci001: "text-davinci-001",
+ DavinciInstructBeta: "davinci-instruct-beta",
+ CurieInstructBeta: "curie-instruct-beta",
+ CodeDavinci002: "code-davinci-002",
+ CodeCushman001: "code-cushman-001",
+ CodeDavinci001: "code-davinci-001",
+ TextDavinciInsert002: "text-davinci-insert-002",
+ TextDavinciInsert001: "text-davinci-insert-001",
+}
+
+var stringToCompletion = map[string]Completion{
+ "text-davinci-003": TextDavinci003,
+ "text-davinci-002": TextDavinci002,
+ "text-curie-001": TextCurie001,
+ "text-babbage-001": TextBabbage001,
+ "text-ada-001": TextAda001,
+ "text-davinci-001": TextDavinci001,
+ "davinci-instruct-beta": DavinciInstructBeta,
+ "curie-instruct-beta": CurieInstructBeta,
+ "code-davinci-002": CodeDavinci002,
+ "code-cushman-001": CodeCushman001,
+ "code-davinci-001": CodeDavinci001,
+ "text-davinci-insert-002": TextDavinciInsert002,
+ "text-davinci-insert-001": TextDavinciInsert001,
}
diff --git a/models/edits.go b/models/edits.go
new file mode 100644
index 0000000..d01474f
--- /dev/null
+++ b/models/edits.go
@@ -0,0 +1,48 @@
+package models
+
+// Edit represents all models available for use with the Edits endpoint.
+type Edit int
+
+const (
+ // UnknownEdit represents an invalid Edit model.
+ UnknownEdit Edit = iota
+ // TextDavinciEdit001 can be used to edit text, rather than just completing it.
+ TextDavinciEdit001
+ // CodeDavinciEdit001 can be used to edit code, rather than just completing it.
+ CodeDavinciEdit001
+)
+
+// String implements the fmt.Stringer interface.
+func (e Edit) String() string {
+ return editToString[e]
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+func (e Edit) MarshalText() ([]byte, error) {
+ return []byte(e.String()), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// On unrecognized value, it sets |e| to UnknownEdit.
+func (e *Edit) UnmarshalText(b []byte) error {
+ if val, ok := stringToEdit[(string(b))]; ok {
+ *e = val
+ return nil
+ }
+
+ *e = UnknownEdit
+
+ return nil
+}
+
+var editToString = map[Edit]string{
+ TextDavinciEdit001: "text-davinci-edit-001",
+ CodeDavinciEdit001: "code-davinci-edit-001",
+}
+
+var stringToEdit = map[string]Edit{
+ "text-davinci-edit-001": TextDavinciEdit001,
+ "code-davinci-edit-001": CodeDavinciEdit001,
+}
diff --git a/models/embeddings.go b/models/embeddings.go
new file mode 100644
index 0000000..569ddc0
--- /dev/null
+++ b/models/embeddings.go
@@ -0,0 +1,118 @@
+package models
+
+// Embedding enumerates the models which can be used
+// to generate Embedding vectors.
+type Embedding int
+
+const (
+ // Unknown represents an invalid Embedding model.
+ Unknown Embedding = iota
+
+ // AdaEmbeddingV2 is the second-generation embedding model. OpenAI recommends using
+ // text-embedding-ada-002 for nearly all use cases. It’s better, cheaper, and simpler to use.
+ //
+ // Supports up to 8191 tokens. Knowledge cutoff Sep 2021.
+ AdaEmbeddingV2
+
+ // The models below are first-generation models (those ending in -001), which use the GPT-3
+ // tokenizer and have a max input of 2046 tokens. First-generation embeddings are generated
+ // by five different model families tuned for three different tasks: text search, text similarity
+ // and code search. The search models come in pairs: one for short queries and one for long documents.
+ // Each family includes up to four models on a spectrum of quality and speed.
+
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ AdaSimilarity
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ BabbageSimilarity
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ CurieSimilarity
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ DavinciSimilarity
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ AdaSearchDocument
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ AdaSearchQuery
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ BabbageSearchDocument
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ BabbageSearchQuery
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ CurieSearchDocument
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ CurieSearchQuery
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ DavinciSearchDocument
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ DavinciSearchQuery
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ AdaCodeSearchCode
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ AdaCodeSearchText
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ BabbageCodeSearchCode
+ // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases.
+ BabbageCodeSearchText
+)
+
+// String implements the fmt.Stringer interface.
+func (e Embedding) String() string {
+ return enumToString[e]
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+func (e Embedding) MarshalText() ([]byte, error) {
+ return []byte(e.String()), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// On unrecognized value, it sets |e| to Unknown.
+func (e *Embedding) UnmarshalText(b []byte) error {
+ if val, ok := stringToEnum[(string(b))]; ok {
+ *e = val
+ return nil
+ }
+
+ *e = Unknown
+
+ return nil
+}
+
+var enumToString = map[Embedding]string{
+ AdaSimilarity: "text-similarity-ada-001",
+ BabbageSimilarity: "text-similarity-babbage-001",
+ CurieSimilarity: "text-similarity-curie-001",
+ DavinciSimilarity: "text-similarity-davinci-001",
+ AdaSearchDocument: "text-search-ada-doc-001",
+ AdaSearchQuery: "text-search-ada-query-001",
+ BabbageSearchDocument: "text-search-babbage-doc-001",
+ BabbageSearchQuery: "text-search-babbage-query-001",
+ CurieSearchDocument: "text-search-curie-doc-001",
+ CurieSearchQuery: "text-search-curie-query-001",
+ DavinciSearchDocument: "text-search-davinci-doc-001",
+ DavinciSearchQuery: "text-search-davinci-query-001",
+ AdaCodeSearchCode: "code-search-ada-code-001",
+ AdaCodeSearchText: "code-search-ada-text-001",
+ BabbageCodeSearchCode: "code-search-babbage-code-001",
+ BabbageCodeSearchText: "code-search-babbage-text-001",
+ AdaEmbeddingV2: "text-embedding-ada-002",
+}
+
+var stringToEnum = map[string]Embedding{
+ "text-similarity-ada-001": AdaSimilarity,
+ "text-similarity-babbage-001": BabbageSimilarity,
+ "text-similarity-curie-001": CurieSimilarity,
+ "text-similarity-davinci-001": DavinciSimilarity,
+ "text-search-ada-doc-001": AdaSearchDocument,
+ "text-search-ada-query-001": AdaSearchQuery,
+ "text-search-babbage-doc-001": BabbageSearchDocument,
+ "text-search-babbage-query-001": BabbageSearchQuery,
+ "text-search-curie-doc-001": CurieSearchDocument,
+ "text-search-curie-query-001": CurieSearchQuery,
+ "text-search-davinci-doc-001": DavinciSearchDocument,
+ "text-search-davinci-query-001": DavinciSearchQuery,
+ "code-search-ada-code-001": AdaCodeSearchCode,
+ "code-search-ada-text-001": AdaCodeSearchText,
+ "code-search-babbage-code-001": BabbageCodeSearchCode,
+ "code-search-babbage-text-001": BabbageCodeSearchText,
+ "text-embedding-ada-002": AdaEmbeddingV2,
+}
diff --git a/models/fine_tunes.go b/models/fine_tunes.go
new file mode 100644
index 0000000..409af8a
--- /dev/null
+++ b/models/fine_tunes.go
@@ -0,0 +1,68 @@
+package models
+
+// FineTune represents all models available for use with the fine-tunes endpoint.
+type FineTune int
+
+const (
+ // UnknownFineTune represents an invalid FineTune model.
+ UnknownFineTune FineTune = iota
+ // Davinci is the most capable of the older versions of the GPT-3 models
+ // and is intended to be used with the fine-tuning endpoints.
+ Davinci
+ // Curie is very capable, but faster and lower cost than Davinci. It is
+ // an older version of the GPT-3 models and is intended to be used with
+ // the fine-tuning endpoints.
+ Curie
+ // Babbage is capable of straightforward tasks, very fast, and lower cost.
+ // It is an older version of the GPT-3 models and is intended to be used
+ // with the fine-tuning endpoints.
+ Babbage
+ // Ada is capable of very simple tasks, usually the fastest model in the
+ // GPT-3 series, and lowest cost. It is an older version of the GPT-3
+ // models and is intended to be used with the fine-tuning endpoints.
+ Ada
+)
+
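All of the model enums in this package share the same text-marshaling behavior; a small sketch of what that implies in practice:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/fabiustech/openai/models"
)

func main() {
	// Model enums serialize as their API string names via MarshalText.
	b, err := json.Marshal(map[string]models.FineTune{"model": models.Davinci})
	if err != nil {
		return
	}
	fmt.Println(string(b)) // {"model":"davinci"}

	// Unrecognized values unmarshal to the Unknown sentinel rather than returning an error.
	var f models.FineTune
	_ = f.UnmarshalText([]byte("not-a-model"))
	fmt.Println(f == models.UnknownFineTune) // true
}
```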
+// String implements the fmt.Stringer interface.
+func (f FineTune) String() string {
+ return fineTuneToString[f]
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+func (f FineTune) MarshalText() ([]byte, error) {
+ return []byte(f.String()), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// On unrecognized value, it sets |f| to UnknownFineTune.
+func (f *FineTune) UnmarshalText(b []byte) error {
+ if val, ok := stringToFineTune[(string(b))]; ok {
+ *f = val
+ return nil
+ }
+
+ *f = UnknownFineTune
+
+ return nil
+}
+
+var fineTuneToString = map[FineTune]string{
+ Davinci: "davinci",
+ Curie: "curie",
+ Ada: "ada",
+ Babbage: "babbage",
+}
+
+var stringToFineTune = map[string]FineTune{
+ "davinci": Davinci,
+ "curie": Curie,
+ "ada": Ada,
+ "babbage": Babbage,
+}
+
+// FineTunedModel represents the name of a fine-tuned model which was
+// previously generated.
+type FineTunedModel string
+
+// NewFineTunedModel converts a string to FineTunedModel.
+func NewFineTunedModel(name string) FineTunedModel {
+ return FineTunedModel(name)
+}
diff --git a/models/moderations.go b/models/moderations.go
new file mode 100644
index 0000000..0d6ff33
--- /dev/null
+++ b/models/moderations.go
@@ -0,0 +1,48 @@
+package models
+
+// Moderation represents all models available for use with the CreateModeration endpoint.
+type Moderation int
+
+const (
+ // UnknownModeration represents an invalid Moderation model.
+ UnknownModeration Moderation = iota
+ // TextModerationStable ...
+ TextModerationStable
+ // TextModerationLatest ...
+ TextModerationLatest
+)
+
+// String implements the fmt.Stringer interface.
+func (m Moderation) String() string {
+ return moderationToString[m]
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+func (m Moderation) MarshalText() ([]byte, error) {
+ return []byte(m.String()), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// On unrecognized value, it sets |m| to UnknownModeration.
+func (m *Moderation) UnmarshalText(b []byte) error {
+ if val, ok := stringToModeration[(string(b))]; ok {
+ *m = val
+ return nil
+ }
+
+ *m = UnknownModeration
+
+ return nil
+}
+
+var moderationToString = map[Moderation]string{
+ TextModerationStable: "text-moderation-stable",
+ TextModerationLatest: "text-moderation-latest",
+}
+
+var stringToModeration = map[string]Moderation{
+ "text-moderation-stable": TextModerationStable,
+ "text-moderation-latest": TextModerationLatest,
+}
diff --git a/moderation.go b/moderation.go
index 1058693..89285ea 100644
--- a/moderation.go
+++ b/moderation.go
@@ -1,23 +1,28 @@
-package gogpt
+package openai

 import (
- "bytes"
 "context"
 "encoding/json"
- "net/http"
+
+ "github.com/fabiustech/openai/models"
+
+ "github.com/fabiustech/openai/routes"
 )

-// ModerationRequest represents a request structure for moderation API.
+// ModerationRequest contains all relevant fields for requests to the moderations endpoint.
 type ModerationRequest struct {
- Input string `json:"input,omitempty"`
- Model *string `json:"model,omitempty"`
+ // Input is the input text to classify.
+ Input string `json:"input,omitempty"`
+ // Model specifies the model to use for moderation.
+ // Defaults to models.TextModerationLatest.
+ Model models.Moderation `json:"model,omitempty"`
 }

 // Result represents one of possible moderation results.
 type Result struct {
- Categories ResultCategories `json:"categories"`
- CategoryScores ResultCategoryScores `json:"category_scores"`
- Flagged bool `json:"flagged"`
+ Categories *ResultCategories `json:"categories"`
+ CategoryScores *ResultCategoryScores `json:"category_scores"`
+ Flagged bool `json:"flagged"`
 }

 // ResultCategories represents Categories of Result.
@@ -49,21 +54,17 @@ type ModerationResponse struct {
 Results []Result `json:"results"`
 }

-// Moderations — perform a moderation api call over a string.
-// Input can be an array or slice but a string will reduce the complexity.
-func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
- var reqBytes []byte
- reqBytes, err = json.Marshal(request)
+// CreateModeration classifies if text violates OpenAI's Content Policy.
+func (c *Client) CreateModeration(ctx context.Context, mr *ModerationRequest) (*ModerationResponse, error) {
+ var b, err = c.post(ctx, routes.Moderations, mr)
 if err != nil {
- return
+ return nil, err
 }

- req, err := http.NewRequest("POST", c.fullURL("/moderations"), bytes.NewBuffer(reqBytes))
- if err != nil {
- return
+ var resp = &ModerationResponse{}
+ if err = json.Unmarshal(b, resp); err != nil {
+ return nil, err
 }

- req = req.WithContext(ctx)
- err = c.sendRequest(req, &response)
- return
+ return resp, nil
 }
diff --git a/objects/objects.go b/objects/objects.go
new file mode 100644
index 0000000..e4d463e
--- /dev/null
+++ b/objects/objects.go
@@ -0,0 +1,80 @@
+// Package objects contains the enum values which represent the various
+// objects returned by all OpenAI endpoints.
+package objects
+
+// Object enumerates the various object types returned by OpenAI endpoints.
+type Object int
+
+const (
+ // Unknown is an invalid object.
+ Unknown Object = iota
+ // Model is a model (can be either a base model or fine-tuned).
+ Model
+ // List is a list of other objects.
+ List
+ // TextCompletion is a text completion.
+ TextCompletion
+ // CodeCompletion is a code completion.
+ CodeCompletion
+ // Edit is an edit.
+ Edit
+ // Embedding is an embedding.
+ Embedding
+ // File is a file.
+ File
+ // FineTune is a fine-tuned model.
+ FineTune
+ // FineTuneEvent is an event associated with a fine-tune job.
+ FineTuneEvent
+ // Engine represents an engine.
+ // Deprecated: use Model instead.
+ Engine
+)
+
+// String implements the fmt.Stringer interface.
+func (o Object) String() string {
+ return objectToString[o]
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+func (o Object) MarshalText() ([]byte, error) {
+ return []byte(o.String()), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// On unrecognized value, it sets |o| to Unknown.
+func (o *Object) UnmarshalText(b []byte) error {
+ if val, ok := stringToObject[(string(b))]; ok {
+ *o = val
+ return nil
+ }
+
+ *o = Unknown
+
+ return nil
+}
+
+var objectToString = map[Object]string{
+ Model: "model",
+ List: "list",
+ TextCompletion: "text_completion",
+ CodeCompletion: "code_completion",
+ Edit: "edit",
+ Embedding: "embedding",
+ File: "file",
+ FineTune: "fine-tune",
+ FineTuneEvent: "fine-tune-event",
+ Engine: "engine",
+}
+
+var stringToObject = map[string]Object{
+ "model": Model,
+ "list": List,
+ "text_completion": TextCompletion,
+ "code_completion": CodeCompletion,
+ "edit": Edit,
+ "embedding": Embedding,
+ "file": File,
+ "fine-tune": FineTune,
+ "fine-tune-event": FineTuneEvent,
+ "engine": Engine,
+}
diff --git a/params/params.go b/params/params.go
new file mode 100644
index 0000000..5555311
--- /dev/null
+++ b/params/params.go
@@ -0,0 +1,8 @@
+// Package params provides a helper function to simplify setting optional
+// parameters in struct literals.
+package params
+
+// Optional returns a pointer to |v|.
+func Optional[T any](v T) *T {
+ return &v
+}
diff --git a/routes/routes.go b/routes/routes.go
new file mode 100644
index 0000000..79ee435
--- /dev/null
+++ b/routes/routes.go
@@ -0,0 +1,43 @@
+// Package routes contains constants for all OpenAI endpoint routes.
+package routes
+
+const (
+ // Completions is the route for the completions endpoint.
+ // https://beta.openai.com/docs/api-reference/completions
+ Completions = "completions"
+ // Edits is the route for the edits endpoint.
+ // https://beta.openai.com/docs/api-reference/edits
+ Edits = "edits"
+ // Embeddings is the route for the embeddings endpoint.
+ // https://beta.openai.com/docs/api-reference/embeddings
+ Embeddings = "embeddings"
+
+ // Engines is the route for the engines endpoint.
+ // https://beta.openai.com/docs/api-reference/engines
+ // Deprecated: Use Models instead.
+ Engines = "engines"
+
+ // Files is the route for the files endpoint.
+ // https://beta.openai.com/docs/api-reference/files
+ Files = "files"
+
+ // FineTunes is the route for the fine-tunes endpoint.
+ // https://beta.openai.com/docs/api-reference/fine-tunes
+ FineTunes = "fine-tunes"
+
+ imagesBase = "images/"
+
+ // ImageGenerations is the route for the create images endpoint.
+ // https://beta.openai.com/docs/api-reference/images/create
+ ImageGenerations = imagesBase + "generations"
+ // ImageEdits is the route for the create image edits endpoint.
+ // https://beta.openai.com/docs/api-reference/images/create-edit
+ ImageEdits = imagesBase + "edits"
+ // ImageVariations is the route for the create image variations endpoint.
+ // https://beta.openai.com/docs/api-reference/images/create-variation
+ ImageVariations = imagesBase + "variations"
+
+ // Moderations is the route for the moderations endpoint.
+ // https://beta.openai.com/docs/api-reference/moderations
+ Moderations = "moderations"
+)
diff --git a/usage.go b/usage.go
new file mode 100644
index 0000000..b49a141
--- /dev/null
+++ b/usage.go
@@ -0,0 +1,12 @@
+package openai
+
+// Usage represents the total token usage per request to OpenAI.
+type Usage struct {
+ // PromptTokens is the number of tokens in the request's prompt.
+ PromptTokens int `json:"prompt_tokens"`
+ // CompletionTokens is the number of tokens in the completion response.
+ // Will not be set for requests to the embeddings endpoint.
+ CompletionTokens int `json:"completion_tokens,omitempty"`
+ // TotalTokens is the sum of PromptTokens and CompletionTokens.
+ TotalTokens int `json:"total_tokens"` +}
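
A minimal sketch of calling one of the new endpoints end to end (the input text is arbitrary; Model is omitted and so defaults to models.TextModerationLatest):

```go
package main

import (
	"context"
	"fmt"

	"github.com/fabiustech/openai"
)

func main() {
	var c = openai.NewClient("your token")

	resp, err := c.CreateModeration(context.Background(), &openai.ModerationRequest{
		Input: "some user-provided text",
	})
	if err != nil {
		return
	}

	for _, r := range resp.Results {
		fmt.Println(r.Flagged)
	}
}
```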