diff --git a/chatcompletion_test.go b/chatcompletion_test.go index 1fb57aa..34f12b8 100644 --- a/chatcompletion_test.go +++ b/chatcompletion_test.go @@ -5,7 +5,10 @@ package openai_test import ( "context" "errors" + "net" + "net/http" "os" + "strings" "testing" "github.com/openai/openai-go" @@ -122,6 +125,54 @@ func TestChatCompletionGet(t *testing.T) { } } +func TestChatCompletionCustomBaseURL(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + srv := &http.Server{} + + ready := make(chan struct{}) + go func() { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.String(), "/openai/v1") { + t.Errorf("expected prefix to be /openai/v1, got %s", r.URL.String()) + } + + w.Header().Set("content-type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id": "completion_id"}`)) + }) + lstr, err := net.Listen("tcp", "localhost:4011") + if err != nil { + t.Errorf("net.Listen: %s", err.Error()) + } + close(ready) + if err := srv.Serve(lstr); err != http.ErrServerClosed { + t.Errorf("srv.Serve: %s", err.Error()) + } + }() + // Wait until the server is listening + <-ready + + go func() { + <-ctx.Done() + srv.Shutdown(ctx) + }() + + baseURL := "http://localhost:4011/openai/v1" + client := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("My API Key"), + ) + _, err := client.Chat.Completions.Get(context.TODO(), "completion_id") + if err != nil { + var apierr *openai.Error + if errors.As(err, &apierr) { + t.Log(string(apierr.DumpRequest(true))) + } + t.Fatalf("err should be nil: %s", err.Error()) + } +} + func TestChatCompletionUpdate(t *testing.T) { baseURL := "http://localhost:4010" if envURL, ok := os.LookupEnv("TEST_API_BASE_URL"); ok { diff --git a/internal/requestconfig/requestconfig.go b/internal/requestconfig/requestconfig.go index a859053..b97f982 100644 --- a/internal/requestconfig/requestconfig.go +++ b/internal/requestconfig/requestconfig.go @@ -353,7 +353,15 @@ func (cfg *RequestConfig) Execute() (err error) { return fmt.Errorf("requestconfig: base url is not set") } - cfg.Request.URL, err = cfg.BaseURL.Parse(strings.TrimLeft(cfg.Request.URL.String(), "/")) + effectiveURL, err := url.JoinPath(cfg.BaseURL.String(), strings.TrimLeft(cfg.Request.URL.Path, "/")) + if err != nil { + return err + } + effectiveURL, err = url.PathUnescape(effectiveURL) + if err != nil { + return err + } + cfg.Request.URL, err = url.Parse(effectiveURL) if err != nil { return err }