-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdalle.go
40 lines (32 loc) · 848 Bytes
/
dalle.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
package imagegeneration
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
providers "github.com/polyfire/api/llm/providers"
openai "github.com/sashabaranov/go-openai"
)
// DALLEGenerate asks the OpenAI image API to render prompt with the given
// model and returns the decoded image bytes as an io.Reader.
//
// Exactly one 1024x1024 image is requested with the base64-JSON response
// format. The returned reader yields the raw bytes obtained by base64-decoding
// that payload (NOTE(review): the image encoding is whatever the API returns,
// presumably PNG — confirm against the API documentation).
//
// An error is returned when the API call fails, when the response carries no
// image data, or when the payload is not valid standard base64. Underlying
// errors are wrapped with %w so callers can inspect them with errors.Is/As.
func DALLEGenerate(ctx context.Context, prompt string, model string) (io.Reader, error) {
	// Reuse the client held by the project's OpenAI provider; its configuration
	// is not visible here.
	client := providers.NewOpenAIStreamProvider(ctx, model).Client

	req := openai.ImageRequest{
		Prompt:         prompt,
		Model:          model,
		Size:           openai.CreateImageSize1024x1024,
		ResponseFormat: openai.CreateImageResponseFormatB64JSON,
		N:              1,
	}

	resp, err := client.CreateImage(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("creating image with model %q: %w", model, err)
	}
	// Even though N is 1, do not trust the response shape: indexing an empty
	// Data slice would panic instead of surfacing an error to the caller.
	if len(resp.Data) == 0 {
		return nil, fmt.Errorf("creating image with model %q: response contained no image data", model)
	}

	imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON)
	if err != nil {
		return nil, fmt.Errorf("decoding base64 image payload: %w", err)
	}
	return bytes.NewReader(imgBytes), nil
}