diff --git a/openaicompat/openaicompat.go b/openaicompat/openaicompat.go index e3fa86afc89c6c9d2944e675ab40ef09dc6c4759..2e1af1b4a17ba3842aeb1fe3dd05c222b09fb773 100644 --- a/openaicompat/openaicompat.go +++ b/openaicompat/openaicompat.go @@ -17,11 +17,10 @@ const ( type Option = func(*options) -func New(url string, opts ...Option) ai.Provider { +func New(opts ...Option) ai.Provider { providerOptions := options{ openaiOptions: []openai.Option{ openai.WithName(Name), - openai.WithBaseURL(url), }, languageModelOptions: []openai.LanguageModelOption{ openai.WithLanguageModelPrepareCallFunc(languagePrepareModelCall), @@ -37,6 +36,12 @@ func New(url string, opts ...Option) ai.Provider { return openai.New(providerOptions.openaiOptions...) } +func WithBaseURL(url string) Option { + return func(o *options) { + o.openaiOptions = append(o.openaiOptions, openai.WithBaseURL(url)) + } +} + func WithAPIKey(apiKey string) Option { return func(o *options) { o.openaiOptions = append(o.openaiOptions, openai.WithAPIKey(apiKey)) diff --git a/providertests/openaicompat_test.go b/providertests/openaicompat_test.go index b0837afa56e49854be346619f4e1f1bcd66c3e25..5d954ab75d285bcc0802bd734528c670e7d8843d 100644 --- a/providertests/openaicompat_test.go +++ b/providertests/openaicompat_test.go @@ -44,12 +44,12 @@ func testOpenAICompatThinking(t *testing.T, result *ai.AgentResult) { } } } - require.Greater(t, reasoningContentCount, 0) + require.Greater(t, reasoningContentCount, 0, "expected reasoning content, got none") } func builderXAIGrokCodeFast(r *recorder.Recorder) (ai.LanguageModel, error) { provider := openaicompat.New( - "https://api.x.ai/v1", + openaicompat.WithBaseURL("https://api.x.ai/v1"), openaicompat.WithAPIKey(os.Getenv("FANTASY_XAI_API_KEY")), openaicompat.WithHTTPClient(&http.Client{Transport: r}), ) @@ -58,7 +58,7 @@ func builderXAIGrokCodeFast(r *recorder.Recorder) (ai.LanguageModel, error) { func builderXAIGrok4Fast(r *recorder.Recorder) (ai.LanguageModel, error) { provider := openaicompat.New( - "https://api.x.ai/v1", + openaicompat.WithBaseURL("https://api.x.ai/v1"), openaicompat.WithAPIKey(os.Getenv("FANTASY_XAI_API_KEY")), openaicompat.WithHTTPClient(&http.Client{Transport: r}), ) @@ -67,7 +67,7 @@ func builderXAIGrok4Fast(r *recorder.Recorder) (ai.LanguageModel, error) { func builderXAIGrok3Mini(r *recorder.Recorder) (ai.LanguageModel, error) { provider := openaicompat.New( - "https://api.x.ai/v1", + openaicompat.WithBaseURL("https://api.x.ai/v1"), openaicompat.WithAPIKey(os.Getenv("FANTASY_XAI_API_KEY")), openaicompat.WithHTTPClient(&http.Client{Transport: r}), ) @@ -76,7 +76,7 @@ func builderXAIGrok3Mini(r *recorder.Recorder) (ai.LanguageModel, error) { func builderZAIGLM45(r *recorder.Recorder) (ai.LanguageModel, error) { provider := openaicompat.New( - "https://api.z.ai/api/coding/paas/v4", + openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"), openaicompat.WithAPIKey(os.Getenv("FANTASY_ZAI_API_KEY")), openaicompat.WithHTTPClient(&http.Client{Transport: r}), ) @@ -85,7 +85,7 @@ func builderZAIGLM45(r *recorder.Recorder) (ai.LanguageModel, error) { func builderGroq(r *recorder.Recorder) (ai.LanguageModel, error) { provider := openaicompat.New( - "https://api.groq.com/openai/v1", + openaicompat.WithBaseURL("https://api.groq.com/openai/v1"), openaicompat.WithAPIKey(os.Getenv("FANTASY_GROQ_API_KEY")), openaicompat.WithHTTPClient(&http.Client{Transport: r}), )