From 2e53c78b300b7f67013af197a50cb4cc4f060d0d Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Fri, 29 Aug 2025 16:59:46 -0300 Subject: [PATCH] refactor: move each provider into its own package Reasoning for this is: 1. Users can import only those they want, and the Go compiler won't compile the external library the user don't need. 2. This simplify the API and makes it follow the Go conventions better: * `ai.NewOpenAiProvider` -> `openai.New` * `ai.WithOpenAiAPIKey` -> `openai.WithAPIKey` * etc. --- examples/agent/main.go | 6 +- examples/simple/main.go | 4 +- examples/stream/main.go | 4 +- examples/streaming-agent-simple/main.go | 6 +- examples/streaming-agent/main.go | 4 +- providers/{ => anthropic}/anthropic.go | 112 ++++----- providers/{ => openai}/openai.go | 86 +++---- providers/{ => openai}/openai_test.go | 296 ++++++++++++------------ 8 files changed, 259 insertions(+), 259 deletions(-) rename providers/{ => anthropic}/anthropic.go (87%) rename providers/{ => openai}/openai.go (94%) rename providers/{ => openai}/openai_test.go (92%) diff --git a/examples/agent/main.go b/examples/agent/main.go index e519e6cb6fc58aca33b4bc0f7f2971a69f63d1c5..97364783f632a5f8623abd3e550142aa70c3f8fe 100644 --- a/examples/agent/main.go +++ b/examples/agent/main.go @@ -6,12 +6,12 @@ import ( "os" "github.com/charmbracelet/ai" - "github.com/charmbracelet/ai/providers" + "github.com/charmbracelet/ai/providers/openai" ) func main() { - provider := providers.NewOpenAiProvider( - providers.WithOpenAiAPIKey(os.Getenv("OPENAI_API_KEY")), + provider := openai.New( + openai.WithAPIKey(os.Getenv("OPENAI_API_KEY")), ) model, err := provider.LanguageModel("gpt-4o") if err != nil { diff --git a/examples/simple/main.go b/examples/simple/main.go index c9fcfd56cd54efbeb42deb8ddeb85af772364c5b..5aea35047fda21a9266c6f17b613b9d75ba90ccc 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -6,11 +6,11 @@ import ( "os" "github.com/charmbracelet/ai" - "github.com/charmbracelet/ai/providers" + "github.com/charmbracelet/ai/providers/anthropic" ) func main() { - provider := providers.NewAnthropicProvider(providers.WithAnthropicAPIKey(os.Getenv("ANTHROPIC_API_KEY"))) + provider := anthropic.New(anthropic.WithAPIKey(os.Getenv("ANTHROPIC_API_KEY"))) model, err := provider.LanguageModel("claude-sonnet-4-20250514") if err != nil { fmt.Println(err) diff --git a/examples/stream/main.go b/examples/stream/main.go index 45583590ba6e8376d35882b4e747ed31f4b1a6b4..1eb3c9ee18f31f03229a34ed194cd703c899f2ef 100644 --- a/examples/stream/main.go +++ b/examples/stream/main.go @@ -7,11 +7,11 @@ import ( "os" "github.com/charmbracelet/ai" - "github.com/charmbracelet/ai/providers" + "github.com/charmbracelet/ai/providers/openai" ) func main() { - provider := providers.NewOpenAiProvider(providers.WithOpenAiAPIKey(os.Getenv("OPENAI_API_KEY"))) + provider := openai.New(openai.WithAPIKey(os.Getenv("OPENAI_API_KEY"))) model, err := provider.LanguageModel("gpt-4o") if err != nil { fmt.Println(err) diff --git a/examples/streaming-agent-simple/main.go b/examples/streaming-agent-simple/main.go index 8184f8951c90b2d4cc05590ecd27d71d6bccbec0..5f94d0a30cdfd518de1afbb05b79448ab34e2842 100644 --- a/examples/streaming-agent-simple/main.go +++ b/examples/streaming-agent-simple/main.go @@ -6,7 +6,7 @@ import ( "os" "github.com/charmbracelet/ai" - "github.com/charmbracelet/ai/providers" + "github.com/charmbracelet/ai/providers/openai" ) func main() { @@ -18,8 +18,8 @@ func main() { } // Create provider and model - provider := providers.NewOpenAiProvider( - providers.WithOpenAiAPIKey(apiKey), + provider := openai.New( + openai.WithAPIKey(apiKey), ) model, err := provider.LanguageModel("gpt-4o-mini") if err != nil { diff --git a/examples/streaming-agent/main.go b/examples/streaming-agent/main.go index 988d23d512707f5025eaab63856217fcc8efa558..fb4fa400599e3814bd8eccf117467b0bebcaeb32 100644 --- a/examples/streaming-agent/main.go +++ b/examples/streaming-agent/main.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/charmbracelet/ai" - "github.com/charmbracelet/ai/providers" + "github.com/charmbracelet/ai/providers/anthropic" ) func main() { @@ -24,7 +24,7 @@ func main() { fmt.Println() // Create OpenAI provider and model - provider := providers.NewAnthropicProvider(providers.WithAnthropicAPIKey(os.Getenv("ANTHROPIC_API_KEY"))) + provider := anthropic.New(anthropic.WithAPIKey(os.Getenv("ANTHROPIC_API_KEY"))) model, err := provider.LanguageModel("claude-sonnet-4-20250514") if err != nil { fmt.Println(err) diff --git a/providers/anthropic.go b/providers/anthropic/anthropic.go similarity index 87% rename from providers/anthropic.go rename to providers/anthropic/anthropic.go index a8924a09db13db7e65fa0c35fd695992c1d39a3a..fbc1381f20b6d8a192bea1d1e2411234e9fe4439 100644 --- a/providers/anthropic.go +++ b/providers/anthropic/anthropic.go @@ -1,4 +1,4 @@ -package providers +package anthropic import ( "context" @@ -16,31 +16,31 @@ import ( "github.com/charmbracelet/ai" ) -type AnthropicProviderOptions struct { - SendReasoning *bool `json:"send_reasoning,omitempty"` - Thinking *AnthropicThinkingProviderOption `json:"thinking,omitempty"` - DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"` +type ProviderOptions struct { + SendReasoning *bool `json:"send_reasoning,omitempty"` + Thinking *ThinkingProviderOption `json:"thinking,omitempty"` + DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"` } -type AnthropicThinkingProviderOption struct { +type ThinkingProviderOption struct { BudgetTokens int64 `json:"budget_tokens"` } -type AnthropicReasoningMetadata struct { +type ReasoningMetadata struct { Signature string `json:"signature"` RedactedData string `json:"redacted_data"` } -type AnthropicCacheControlProviderOptions struct { +type CacheControlProviderOptions struct { Type string `json:"type"` } -type AnthropicFilePartProviderOptions struct { +type FilePartProviderOptions struct { EnableCitations bool `json:"enable_citations"` Title string `json:"title"` Context string `json:"context"` } -type anthropicProviderOptions struct { +type options struct { baseURL string apiKey string name string @@ -48,14 +48,14 @@ type anthropicProviderOptions struct { client option.HTTPClient } -type anthropicProvider struct { - options anthropicProviderOptions +type provider struct { + options options } -type AnthropicOption = func(*anthropicProviderOptions) +type Option = func(*options) -func NewAnthropicProvider(opts ...AnthropicOption) ai.Provider { - options := anthropicProviderOptions{ +func New(opts ...Option) ai.Provider { + options := options{ headers: map[string]string{}, } for _, o := range opts { @@ -69,42 +69,42 @@ func NewAnthropicProvider(opts ...AnthropicOption) ai.Provider { options.name = "anthropic" } - return &anthropicProvider{ + return &provider{ options: options, } } -func WithAnthropicBaseURL(baseURL string) AnthropicOption { - return func(o *anthropicProviderOptions) { +func WithBaseURL(baseURL string) Option { + return func(o *options) { o.baseURL = baseURL } } -func WithAnthropicAPIKey(apiKey string) AnthropicOption { - return func(o *anthropicProviderOptions) { +func WithAPIKey(apiKey string) Option { + return func(o *options) { o.apiKey = apiKey } } -func WithAnthropicName(name string) AnthropicOption { - return func(o *anthropicProviderOptions) { +func WithName(name string) Option { + return func(o *options) { o.name = name } } -func WithAnthropicHeaders(headers map[string]string) AnthropicOption { - return func(o *anthropicProviderOptions) { +func WithHeaders(headers map[string]string) Option { + return func(o *options) { maps.Copy(o.headers, headers) } } -func WithAnthropicHTTPClient(client option.HTTPClient) AnthropicOption { - return func(o *anthropicProviderOptions) { +func WithHTTPClient(client option.HTTPClient) Option { + return func(o *options) { o.client = client } } -func (a *anthropicProvider) LanguageModel(modelID string) (ai.LanguageModel, error) { +func (a *provider) LanguageModel(modelID string) (ai.LanguageModel, error) { anthropicClientOptions := []option.RequestOption{} if a.options.apiKey != "" { anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(a.options.apiKey)) @@ -120,34 +120,34 @@ func (a *anthropicProvider) LanguageModel(modelID string) (ai.LanguageModel, err if a.options.client != nil { anthropicClientOptions = append(anthropicClientOptions, option.WithHTTPClient(a.options.client)) } - return anthropicLanguageModel{ - modelID: modelID, - provider: fmt.Sprintf("%s.messages", a.options.name), - providerOptions: a.options, - client: anthropic.NewClient(anthropicClientOptions...), + return languageModel{ + modelID: modelID, + provider: fmt.Sprintf("%s.messages", a.options.name), + options: a.options, + client: anthropic.NewClient(anthropicClientOptions...), }, nil } -type anthropicLanguageModel struct { - provider string - modelID string - client anthropic.Client - providerOptions anthropicProviderOptions +type languageModel struct { + provider string + modelID string + client anthropic.Client + options options } // Model implements ai.LanguageModel. -func (a anthropicLanguageModel) Model() string { +func (a languageModel) Model() string { return a.modelID } // Provider implements ai.LanguageModel. -func (a anthropicLanguageModel) Provider() string { +func (a languageModel) Provider() string { return a.provider } -func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) { +func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) { params := &anthropic.MessageNewParams{} - providerOptions := &AnthropicProviderOptions{} + providerOptions := &ProviderOptions{} if v, ok := call.ProviderOptions["anthropic"]; ok { err := ai.ParseOptions(v, providerOptions) if err != nil { @@ -158,7 +158,7 @@ func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageN if providerOptions.SendReasoning != nil { sendReasoning = *providerOptions.SendReasoning } - systemBlocks, messages, warnings := toAnthropicPrompt(call.Prompt, sendReasoning) + systemBlocks, messages, warnings := toPrompt(call.Prompt, sendReasoning) if call.FrequencyPenalty != nil { warnings = append(warnings, ai.CallWarning{ @@ -235,7 +235,7 @@ func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageN if providerOptions.DisableParallelToolUse != nil { disableParallelToolUse = *providerOptions.DisableParallelToolUse } - tools, toolChoice, toolWarnings := toAnthropicTools(call.Tools, call.ToolChoice, disableParallelToolUse) + tools, toolChoice, toolWarnings := toTools(call.Tools, call.ToolChoice, disableParallelToolUse) params.Tools = tools if toolChoice != nil { params.ToolChoice = *toolChoice @@ -246,11 +246,11 @@ func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageN return params, warnings, nil } -func getCacheControl(providerOptions ai.ProviderOptions) *AnthropicCacheControlProviderOptions { +func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOptions { if anthropicOptions, ok := providerOptions["anthropic"]; ok { if cacheControl, ok := anthropicOptions["cache_control"]; ok { if cc, ok := cacheControl.(map[string]any); ok { - cacheControlOption := &AnthropicCacheControlProviderOptions{} + cacheControlOption := &CacheControlProviderOptions{} err := ai.ParseOptions(cc, cacheControlOption) if err != nil { return cacheControlOption @@ -258,7 +258,7 @@ func getCacheControl(providerOptions ai.ProviderOptions) *AnthropicCacheControlP } } else if cacheControl, ok := anthropicOptions["cacheControl"]; ok { if cc, ok := cacheControl.(map[string]any); ok { - cacheControlOption := &AnthropicCacheControlProviderOptions{} + cacheControlOption := &CacheControlProviderOptions{} err := ai.ParseOptions(cc, cacheControlOption) if err != nil { return cacheControlOption @@ -269,9 +269,9 @@ func getCacheControl(providerOptions ai.ProviderOptions) *AnthropicCacheControlP return nil } -func getReasoningMetadata(providerOptions ai.ProviderOptions) *AnthropicReasoningMetadata { +func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata { if anthropicOptions, ok := providerOptions["anthropic"]; ok { - reasoningMetadata := &AnthropicReasoningMetadata{} + reasoningMetadata := &ReasoningMetadata{} err := ai.ParseOptions(anthropicOptions, reasoningMetadata) if err != nil { return reasoningMetadata @@ -333,7 +333,7 @@ func groupIntoBlocks(prompt ai.Prompt) []*messageBlock { return blocks } -func toAnthropicTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) { +func toTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) { for _, tool := range tools { if tool.GetType() == ai.ToolTypeFunction { ft, ok := tool.(ai.FunctionTool) @@ -414,7 +414,7 @@ func toAnthropicTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParalle return anthropicTools, anthropicToolChoice, warnings } -func toAnthropicPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) { +func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) { var systemBlocks []anthropic.TextBlockParam var messages []anthropic.MessageParam var warnings []ai.CallWarning @@ -638,7 +638,7 @@ func toAnthropicPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.Te return systemBlocks, messages, warnings } -func (o anthropicLanguageModel) handleError(err error) error { +func (o languageModel) handleError(err error) error { var apiErr *anthropic.Error if errors.As(err, &apiErr) { requestDump := apiErr.DumpRequest(true) @@ -662,7 +662,7 @@ func (o anthropicLanguageModel) handleError(err error) error { return err } -func mapAnthropicFinishReason(finishReason string) ai.FinishReason { +func mapFinishReason(finishReason string) ai.FinishReason { switch finishReason { case "end", "stop_sequence": return ai.FinishReasonStop @@ -676,7 +676,7 @@ func mapAnthropicFinishReason(finishReason string) ai.FinishReason { } // Generate implements ai.LanguageModel. -func (a anthropicLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { +func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { params, warnings, err := a.prepareParams(call) if err != nil { return nil, err @@ -746,7 +746,7 @@ func (a anthropicLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai CacheCreationTokens: response.Usage.CacheCreationInputTokens, CacheReadTokens: response.Usage.CacheReadInputTokens, }, - FinishReason: mapAnthropicFinishReason(string(response.StopReason)), + FinishReason: mapFinishReason(string(response.StopReason)), ProviderMetadata: ai.ProviderMetadata{ "anthropic": make(map[string]any), }, @@ -755,7 +755,7 @@ func (a anthropicLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai } // Stream implements ai.LanguageModel. -func (a anthropicLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { +func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { params, warnings, err := a.prepareParams(call) if err != nil { return nil, err @@ -904,7 +904,7 @@ func (a anthropicLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St yield(ai.StreamPart{ Type: ai.StreamPartTypeFinish, ID: acc.ID, - FinishReason: mapAnthropicFinishReason(string(acc.StopReason)), + FinishReason: mapFinishReason(string(acc.StopReason)), Usage: ai.Usage{ InputTokens: acc.Usage.InputTokens, OutputTokens: acc.Usage.OutputTokens, diff --git a/providers/openai.go b/providers/openai/openai.go similarity index 94% rename from providers/openai.go rename to providers/openai/openai.go index e3d5594c964da67a22511d02a3b364da5bcfb81f..f7f152eb423c39ea43ea4d276c68c259dbdd8aff 100644 --- a/providers/openai.go +++ b/providers/openai/openai.go @@ -1,4 +1,4 @@ -package providers +package openai import ( "context" @@ -28,7 +28,7 @@ const ( ReasoningEffortHigh ReasoningEffort = "high" ) -type OpenAiProviderOptions struct { +type ProviderOptions struct { LogitBias map[string]int64 `json:"logit_bias"` LogProbs *bool `json:"log_probes"` TopLogProbs *int64 `json:"top_log_probs"` @@ -46,11 +46,11 @@ type OpenAiProviderOptions struct { StructuredOutputs *bool `json:"structured_outputs"` } -type openAiProvider struct { - options openAiProviderOptions +type provider struct { + options options } -type openAiProviderOptions struct { +type options struct { baseURL string apiKey string organization string @@ -60,10 +60,10 @@ type openAiProviderOptions struct { client option.HTTPClient } -type OpenAiOption = func(*openAiProviderOptions) +type Option = func(*options) -func NewOpenAiProvider(opts ...OpenAiOption) ai.Provider { - options := openAiProviderOptions{ +func New(opts ...Option) ai.Provider { + options := options{ headers: map[string]string{}, } for _, o := range opts { @@ -86,55 +86,55 @@ func NewOpenAiProvider(opts ...OpenAiOption) ai.Provider { options.headers["OpenAi-Project"] = options.project } - return &openAiProvider{ + return &provider{ options: options, } } -func WithOpenAiBaseURL(baseURL string) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithBaseURL(baseURL string) Option { + return func(o *options) { o.baseURL = baseURL } } -func WithOpenAiAPIKey(apiKey string) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithAPIKey(apiKey string) Option { + return func(o *options) { o.apiKey = apiKey } } -func WithOpenAiOrganization(organization string) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithOrganization(organization string) Option { + return func(o *options) { o.organization = organization } } -func WithOpenAiProject(project string) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithProject(project string) Option { + return func(o *options) { o.project = project } } -func WithOpenAiName(name string) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithName(name string) Option { + return func(o *options) { o.name = name } } -func WithOpenAiHeaders(headers map[string]string) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithHeaders(headers map[string]string) Option { + return func(o *options) { maps.Copy(o.headers, headers) } } -func WithOpenAiHTTPClient(client option.HTTPClient) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithHTTPClient(client option.HTTPClient) Option { + return func(o *options) { o.client = client } } // LanguageModel implements ai.Provider. -func (o *openAiProvider) LanguageModel(modelID string) (ai.LanguageModel, error) { +func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) { openaiClientOptions := []option.RequestOption{} if o.options.apiKey != "" { openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey)) @@ -151,35 +151,35 @@ func (o *openAiProvider) LanguageModel(modelID string) (ai.LanguageModel, error) openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client)) } - return openAiLanguageModel{ - modelID: modelID, - provider: fmt.Sprintf("%s.chat", o.options.name), - providerOptions: o.options, - client: openai.NewClient(openaiClientOptions...), + return languageModel{ + modelID: modelID, + provider: fmt.Sprintf("%s.chat", o.options.name), + options: o.options, + client: openai.NewClient(openaiClientOptions...), }, nil } -type openAiLanguageModel struct { - provider string - modelID string - client openai.Client - providerOptions openAiProviderOptions +type languageModel struct { + provider string + modelID string + client openai.Client + options options } // Model implements ai.LanguageModel. -func (o openAiLanguageModel) Model() string { +func (o languageModel) Model() string { return o.modelID } // Provider implements ai.LanguageModel. -func (o openAiLanguageModel) Provider() string { +func (o languageModel) Provider() string { return o.provider } -func (o openAiLanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) { +func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) { params := &openai.ChatCompletionNewParams{} - messages, warnings := toOpenAiPrompt(call.Prompt) - providerOptions := &OpenAiProviderOptions{} + messages, warnings := toPrompt(call.Prompt) + providerOptions := &ProviderOptions{} if v, ok := call.ProviderOptions["openai"]; ok { err := ai.ParseOptions(v, providerOptions) if err != nil { @@ -398,7 +398,7 @@ func (o openAiLanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletion return params, warnings, nil } -func (o openAiLanguageModel) handleError(err error) error { +func (o languageModel) handleError(err error) error { var apiErr *openai.Error if errors.As(err, &apiErr) { requestDump := apiErr.DumpRequest(true) @@ -423,7 +423,7 @@ func (o openAiLanguageModel) handleError(err error) error { } // Generate implements ai.LanguageModel. -func (o openAiLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { +func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { params, warnings, err := o.prepareParams(call) if err != nil { return nil, err @@ -515,7 +515,7 @@ type toolCall struct { } // Stream implements ai.LanguageModel. -func (o openAiLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { +func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { params, warnings, err := o.prepareParams(call) if err != nil { return nil, err @@ -865,7 +865,7 @@ func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []op return openAiTools, openAiToolChoice, warnings } -func toOpenAiPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) { +func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) { var messages []openai.ChatCompletionMessageParamUnion var warnings []ai.CallWarning for _, msg := range prompt { diff --git a/providers/openai_test.go b/providers/openai/openai_test.go similarity index 92% rename from providers/openai_test.go rename to providers/openai/openai_test.go index f5c0bc5240a74f0d66428fe4ea0f6e82c8b72f86..a0c9ae9e9af178f336166337439eeea974a99caf 100644 --- a/providers/openai_test.go +++ b/providers/openai/openai_test.go @@ -1,4 +1,4 @@ -package providers +package openai import ( "context" @@ -30,7 +30,7 @@ func TestToOpenAiPrompt_SystemMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -50,7 +50,7 @@ func TestToOpenAiPrompt_SystemMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Len(t, warnings, 1) require.Contains(t, warnings[0].Message, "system prompt has no text parts") @@ -70,7 +70,7 @@ func TestToOpenAiPrompt_SystemMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -96,7 +96,7 @@ func TestToOpenAiPrompt_UserMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -123,7 +123,7 @@ func TestToOpenAiPrompt_UserMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -167,7 +167,7 @@ func TestToOpenAiPrompt_UserMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -202,7 +202,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Len(t, warnings, 1) require.Contains(t, warnings[0].Message, "file part media type application/something not supported") @@ -225,7 +225,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -258,7 +258,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -286,7 +286,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -315,7 +315,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -349,7 +349,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -378,7 +378,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -408,7 +408,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -457,7 +457,7 @@ func TestToOpenAiPrompt_ToolCalls(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 2) @@ -504,7 +504,7 @@ func TestToOpenAiPrompt_ToolCalls(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 2) @@ -538,7 +538,7 @@ func TestToOpenAiPrompt_AssistantMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -568,7 +568,7 @@ func TestToOpenAiPrompt_AssistantMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -811,9 +811,9 @@ func TestDoGenerate(t *testing.T) { "content": "Hello, World!", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -843,9 +843,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -867,9 +867,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -907,9 +907,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -933,9 +933,9 @@ func TestDoGenerate(t *testing.T) { "logprobs": testLogprobs, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -969,9 +969,9 @@ func TestDoGenerate(t *testing.T) { "finish_reason": "stop", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -993,9 +993,9 @@ func TestDoGenerate(t *testing.T) { "finish_reason": "eos", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1017,9 +1017,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1049,9 +1049,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1093,9 +1093,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-mini") @@ -1133,9 +1133,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o") @@ -1173,9 +1173,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1245,9 +1245,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1303,9 +1303,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1345,9 +1345,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o-mini") @@ -1380,9 +1380,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o-mini") @@ -1407,9 +1407,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-preview") @@ -1455,9 +1455,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-preview") @@ -1499,9 +1499,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-preview") @@ -1526,9 +1526,9 @@ func TestDoGenerate(t *testing.T) { "model": "o1-preview", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-preview") @@ -1566,9 +1566,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1612,9 +1612,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1652,9 +1652,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1696,9 +1696,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1736,9 +1736,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1774,9 +1774,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o-search-preview") @@ -1808,9 +1808,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o3-mini") @@ -1846,9 +1846,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o-mini") @@ -1881,9 +1881,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o-mini") @@ -1919,9 +1919,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2228,9 +2228,9 @@ func TestDoStream(t *testing.T) { "logprobs": testLogprobs, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2284,9 +2284,9 @@ func TestDoStream(t *testing.T) { server.prepareToolStreamResponse() - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2370,9 +2370,9 @@ func TestDoStream(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2409,9 +2409,9 @@ func TestDoStream(t *testing.T) { server.prepareErrorStreamResponse() - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2450,9 +2450,9 @@ func TestDoStream(t *testing.T) { "content": []string{}, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2498,9 +2498,9 @@ func TestDoStream(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2548,9 +2548,9 @@ func TestDoStream(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2591,9 +2591,9 @@ func TestDoStream(t *testing.T) { "content": []string{}, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2635,9 +2635,9 @@ func TestDoStream(t *testing.T) { "content": []string{}, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2683,9 +2683,9 @@ func TestDoStream(t *testing.T) { "content": []string{}, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o3-mini") @@ -2727,9 +2727,9 @@ func TestDoStream(t *testing.T) { "content": []string{}, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o-mini") @@ -2772,9 +2772,9 @@ func TestDoStream(t *testing.T) { "model": "o1-preview", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-preview") @@ -2818,9 +2818,9 @@ func TestDoStream(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-preview")