diff --git a/examples/agent/main.go b/examples/agent/main.go index e519e6cb6fc58aca33b4bc0f7f2971a69f63d1c5..97364783f632a5f8623abd3e550142aa70c3f8fe 100644 --- a/examples/agent/main.go +++ b/examples/agent/main.go @@ -6,12 +6,12 @@ import ( "os" "github.com/charmbracelet/ai" - "github.com/charmbracelet/ai/providers" + "github.com/charmbracelet/ai/providers/openai" ) func main() { - provider := providers.NewOpenAiProvider( - providers.WithOpenAiAPIKey(os.Getenv("OPENAI_API_KEY")), + provider := openai.New( + openai.WithAPIKey(os.Getenv("OPENAI_API_KEY")), ) model, err := provider.LanguageModel("gpt-4o") if err != nil { diff --git a/examples/simple/main.go b/examples/simple/main.go index c9fcfd56cd54efbeb42deb8ddeb85af772364c5b..5aea35047fda21a9266c6f17b613b9d75ba90ccc 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -6,11 +6,11 @@ import ( "os" "github.com/charmbracelet/ai" - "github.com/charmbracelet/ai/providers" + "github.com/charmbracelet/ai/providers/anthropic" ) func main() { - provider := providers.NewAnthropicProvider(providers.WithAnthropicAPIKey(os.Getenv("ANTHROPIC_API_KEY"))) + provider := anthropic.New(anthropic.WithAPIKey(os.Getenv("ANTHROPIC_API_KEY"))) model, err := provider.LanguageModel("claude-sonnet-4-20250514") if err != nil { fmt.Println(err) diff --git a/examples/stream/main.go b/examples/stream/main.go index 45583590ba6e8376d35882b4e747ed31f4b1a6b4..1eb3c9ee18f31f03229a34ed194cd703c899f2ef 100644 --- a/examples/stream/main.go +++ b/examples/stream/main.go @@ -7,11 +7,11 @@ import ( "os" "github.com/charmbracelet/ai" - "github.com/charmbracelet/ai/providers" + "github.com/charmbracelet/ai/providers/openai" ) func main() { - provider := providers.NewOpenAiProvider(providers.WithOpenAiAPIKey(os.Getenv("OPENAI_API_KEY"))) + provider := openai.New(openai.WithAPIKey(os.Getenv("OPENAI_API_KEY"))) model, err := provider.LanguageModel("gpt-4o") if err != nil { fmt.Println(err) diff --git a/examples/streaming-agent-simple/main.go b/examples/streaming-agent-simple/main.go index 8184f8951c90b2d4cc05590ecd27d71d6bccbec0..5f94d0a30cdfd518de1afbb05b79448ab34e2842 100644 --- a/examples/streaming-agent-simple/main.go +++ b/examples/streaming-agent-simple/main.go @@ -6,7 +6,7 @@ import ( "os" "github.com/charmbracelet/ai" - "github.com/charmbracelet/ai/providers" + "github.com/charmbracelet/ai/providers/openai" ) func main() { @@ -18,8 +18,8 @@ func main() { } // Create provider and model - provider := providers.NewOpenAiProvider( - providers.WithOpenAiAPIKey(apiKey), + provider := openai.New( + openai.WithAPIKey(apiKey), ) model, err := provider.LanguageModel("gpt-4o-mini") if err != nil { diff --git a/examples/streaming-agent/main.go b/examples/streaming-agent/main.go index 988d23d512707f5025eaab63856217fcc8efa558..fb4fa400599e3814bd8eccf117467b0bebcaeb32 100644 --- a/examples/streaming-agent/main.go +++ b/examples/streaming-agent/main.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/charmbracelet/ai" - "github.com/charmbracelet/ai/providers" + "github.com/charmbracelet/ai/providers/anthropic" ) func main() { @@ -24,7 +24,7 @@ func main() { fmt.Println() // Create OpenAI provider and model - provider := providers.NewAnthropicProvider(providers.WithAnthropicAPIKey(os.Getenv("ANTHROPIC_API_KEY"))) + provider := anthropic.New(anthropic.WithAPIKey(os.Getenv("ANTHROPIC_API_KEY"))) model, err := provider.LanguageModel("claude-sonnet-4-20250514") if err != nil { fmt.Println(err) diff --git a/providers/anthropic.go b/providers/anthropic/anthropic.go similarity index 87% rename from providers/anthropic.go rename to providers/anthropic/anthropic.go index a8924a09db13db7e65fa0c35fd695992c1d39a3a..fbc1381f20b6d8a192bea1d1e2411234e9fe4439 100644 --- a/providers/anthropic.go +++ b/providers/anthropic/anthropic.go @@ -1,4 +1,4 @@ -package providers +package anthropic import ( "context" @@ -16,31 +16,31 @@ import ( "github.com/charmbracelet/ai" ) -type AnthropicProviderOptions struct { - SendReasoning *bool `json:"send_reasoning,omitempty"` - Thinking *AnthropicThinkingProviderOption `json:"thinking,omitempty"` - DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"` +type ProviderOptions struct { + SendReasoning *bool `json:"send_reasoning,omitempty"` + Thinking *ThinkingProviderOption `json:"thinking,omitempty"` + DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"` } -type AnthropicThinkingProviderOption struct { +type ThinkingProviderOption struct { BudgetTokens int64 `json:"budget_tokens"` } -type AnthropicReasoningMetadata struct { +type ReasoningMetadata struct { Signature string `json:"signature"` RedactedData string `json:"redacted_data"` } -type AnthropicCacheControlProviderOptions struct { +type CacheControlProviderOptions struct { Type string `json:"type"` } -type AnthropicFilePartProviderOptions struct { +type FilePartProviderOptions struct { EnableCitations bool `json:"enable_citations"` Title string `json:"title"` Context string `json:"context"` } -type anthropicProviderOptions struct { +type options struct { baseURL string apiKey string name string @@ -48,14 +48,14 @@ type anthropicProviderOptions struct { client option.HTTPClient } -type anthropicProvider struct { - options anthropicProviderOptions +type provider struct { + options options } -type AnthropicOption = func(*anthropicProviderOptions) +type Option = func(*options) -func NewAnthropicProvider(opts ...AnthropicOption) ai.Provider { - options := anthropicProviderOptions{ +func New(opts ...Option) ai.Provider { + options := options{ headers: map[string]string{}, } for _, o := range opts { @@ -69,42 +69,42 @@ func NewAnthropicProvider(opts ...AnthropicOption) ai.Provider { options.name = "anthropic" } - return &anthropicProvider{ + return &provider{ options: options, } } -func WithAnthropicBaseURL(baseURL string) AnthropicOption { - return func(o *anthropicProviderOptions) { +func WithBaseURL(baseURL string) Option { + return func(o *options) { o.baseURL = baseURL } } -func WithAnthropicAPIKey(apiKey string) AnthropicOption { - return func(o *anthropicProviderOptions) { +func WithAPIKey(apiKey string) Option { + return func(o *options) { o.apiKey = apiKey } } -func WithAnthropicName(name string) AnthropicOption { - return func(o *anthropicProviderOptions) { +func WithName(name string) Option { + return func(o *options) { o.name = name } } -func WithAnthropicHeaders(headers map[string]string) AnthropicOption { - return func(o *anthropicProviderOptions) { +func WithHeaders(headers map[string]string) Option { + return func(o *options) { maps.Copy(o.headers, headers) } } -func WithAnthropicHTTPClient(client option.HTTPClient) AnthropicOption { - return func(o *anthropicProviderOptions) { +func WithHTTPClient(client option.HTTPClient) Option { + return func(o *options) { o.client = client } } -func (a *anthropicProvider) LanguageModel(modelID string) (ai.LanguageModel, error) { +func (a *provider) LanguageModel(modelID string) (ai.LanguageModel, error) { anthropicClientOptions := []option.RequestOption{} if a.options.apiKey != "" { anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(a.options.apiKey)) @@ -120,34 +120,34 @@ func (a *anthropicProvider) LanguageModel(modelID string) (ai.LanguageModel, err if a.options.client != nil { anthropicClientOptions = append(anthropicClientOptions, option.WithHTTPClient(a.options.client)) } - return anthropicLanguageModel{ - modelID: modelID, - provider: fmt.Sprintf("%s.messages", a.options.name), - providerOptions: a.options, - client: anthropic.NewClient(anthropicClientOptions...), + return languageModel{ + modelID: modelID, + provider: fmt.Sprintf("%s.messages", a.options.name), + options: a.options, + client: anthropic.NewClient(anthropicClientOptions...), }, nil } -type anthropicLanguageModel struct { - provider string - modelID string - client anthropic.Client - providerOptions anthropicProviderOptions +type languageModel struct { + provider string + modelID string + client anthropic.Client + options options } // Model implements ai.LanguageModel. -func (a anthropicLanguageModel) Model() string { +func (a languageModel) Model() string { return a.modelID } // Provider implements ai.LanguageModel. -func (a anthropicLanguageModel) Provider() string { +func (a languageModel) Provider() string { return a.provider } -func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) { +func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) { params := &anthropic.MessageNewParams{} - providerOptions := &AnthropicProviderOptions{} + providerOptions := &ProviderOptions{} if v, ok := call.ProviderOptions["anthropic"]; ok { err := ai.ParseOptions(v, providerOptions) if err != nil { @@ -158,7 +158,7 @@ func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageN if providerOptions.SendReasoning != nil { sendReasoning = *providerOptions.SendReasoning } - systemBlocks, messages, warnings := toAnthropicPrompt(call.Prompt, sendReasoning) + systemBlocks, messages, warnings := toPrompt(call.Prompt, sendReasoning) if call.FrequencyPenalty != nil { warnings = append(warnings, ai.CallWarning{ @@ -235,7 +235,7 @@ func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageN if providerOptions.DisableParallelToolUse != nil { disableParallelToolUse = *providerOptions.DisableParallelToolUse } - tools, toolChoice, toolWarnings := toAnthropicTools(call.Tools, call.ToolChoice, disableParallelToolUse) + tools, toolChoice, toolWarnings := toTools(call.Tools, call.ToolChoice, disableParallelToolUse) params.Tools = tools if toolChoice != nil { params.ToolChoice = *toolChoice @@ -246,11 +246,11 @@ func (a anthropicLanguageModel) prepareParams(call ai.Call) (*anthropic.MessageN return params, warnings, nil } -func getCacheControl(providerOptions ai.ProviderOptions) *AnthropicCacheControlProviderOptions { +func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOptions { if anthropicOptions, ok := providerOptions["anthropic"]; ok { if cacheControl, ok := anthropicOptions["cache_control"]; ok { if cc, ok := cacheControl.(map[string]any); ok { - cacheControlOption := &AnthropicCacheControlProviderOptions{} + cacheControlOption := &CacheControlProviderOptions{} err := ai.ParseOptions(cc, cacheControlOption) if err != nil { return cacheControlOption @@ -258,7 +258,7 @@ func getCacheControl(providerOptions ai.ProviderOptions) *AnthropicCacheControlP } } else if cacheControl, ok := anthropicOptions["cacheControl"]; ok { if cc, ok := cacheControl.(map[string]any); ok { - cacheControlOption := &AnthropicCacheControlProviderOptions{} + cacheControlOption := &CacheControlProviderOptions{} err := ai.ParseOptions(cc, cacheControlOption) if err != nil { return cacheControlOption @@ -269,9 +269,9 @@ func getCacheControl(providerOptions ai.ProviderOptions) *AnthropicCacheControlP return nil } -func getReasoningMetadata(providerOptions ai.ProviderOptions) *AnthropicReasoningMetadata { +func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata { if anthropicOptions, ok := providerOptions["anthropic"]; ok { - reasoningMetadata := &AnthropicReasoningMetadata{} + reasoningMetadata := &ReasoningMetadata{} err := ai.ParseOptions(anthropicOptions, reasoningMetadata) if err != nil { return reasoningMetadata @@ -333,7 +333,7 @@ func groupIntoBlocks(prompt ai.Prompt) []*messageBlock { return blocks } -func toAnthropicTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) { +func toTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParallelToolCalls bool) (anthropicTools []anthropic.ToolUnionParam, anthropicToolChoice *anthropic.ToolChoiceUnionParam, warnings []ai.CallWarning) { for _, tool := range tools { if tool.GetType() == ai.ToolTypeFunction { ft, ok := tool.(ai.FunctionTool) @@ -414,7 +414,7 @@ func toAnthropicTools(tools []ai.Tool, toolChoice *ai.ToolChoice, disableParalle return anthropicTools, anthropicToolChoice, warnings } -func toAnthropicPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) { +func toPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.TextBlockParam, []anthropic.MessageParam, []ai.CallWarning) { var systemBlocks []anthropic.TextBlockParam var messages []anthropic.MessageParam var warnings []ai.CallWarning @@ -638,7 +638,7 @@ func toAnthropicPrompt(prompt ai.Prompt, sendReasoningData bool) ([]anthropic.Te return systemBlocks, messages, warnings } -func (o anthropicLanguageModel) handleError(err error) error { +func (o languageModel) handleError(err error) error { var apiErr *anthropic.Error if errors.As(err, &apiErr) { requestDump := apiErr.DumpRequest(true) @@ -662,7 +662,7 @@ func (o anthropicLanguageModel) handleError(err error) error { return err } -func mapAnthropicFinishReason(finishReason string) ai.FinishReason { +func mapFinishReason(finishReason string) ai.FinishReason { switch finishReason { case "end", "stop_sequence": return ai.FinishReasonStop @@ -676,7 +676,7 @@ func mapAnthropicFinishReason(finishReason string) ai.FinishReason { } // Generate implements ai.LanguageModel. -func (a anthropicLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { +func (a languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { params, warnings, err := a.prepareParams(call) if err != nil { return nil, err @@ -746,7 +746,7 @@ func (a anthropicLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai CacheCreationTokens: response.Usage.CacheCreationInputTokens, CacheReadTokens: response.Usage.CacheReadInputTokens, }, - FinishReason: mapAnthropicFinishReason(string(response.StopReason)), + FinishReason: mapFinishReason(string(response.StopReason)), ProviderMetadata: ai.ProviderMetadata{ "anthropic": make(map[string]any), }, @@ -755,7 +755,7 @@ func (a anthropicLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai } // Stream implements ai.LanguageModel. -func (a anthropicLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { +func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { params, warnings, err := a.prepareParams(call) if err != nil { return nil, err @@ -904,7 +904,7 @@ func (a anthropicLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.St yield(ai.StreamPart{ Type: ai.StreamPartTypeFinish, ID: acc.ID, - FinishReason: mapAnthropicFinishReason(string(acc.StopReason)), + FinishReason: mapFinishReason(string(acc.StopReason)), Usage: ai.Usage{ InputTokens: acc.Usage.InputTokens, OutputTokens: acc.Usage.OutputTokens, diff --git a/providers/openai.go b/providers/openai/openai.go similarity index 94% rename from providers/openai.go rename to providers/openai/openai.go index e3d5594c964da67a22511d02a3b364da5bcfb81f..f7f152eb423c39ea43ea4d276c68c259dbdd8aff 100644 --- a/providers/openai.go +++ b/providers/openai/openai.go @@ -1,4 +1,4 @@ -package providers +package openai import ( "context" @@ -28,7 +28,7 @@ const ( ReasoningEffortHigh ReasoningEffort = "high" ) -type OpenAiProviderOptions struct { +type ProviderOptions struct { LogitBias map[string]int64 `json:"logit_bias"` LogProbs *bool `json:"log_probes"` TopLogProbs *int64 `json:"top_log_probs"` @@ -46,11 +46,11 @@ type OpenAiProviderOptions struct { StructuredOutputs *bool `json:"structured_outputs"` } -type openAiProvider struct { - options openAiProviderOptions +type provider struct { + options options } -type openAiProviderOptions struct { +type options struct { baseURL string apiKey string organization string @@ -60,10 +60,10 @@ type openAiProviderOptions struct { client option.HTTPClient } -type OpenAiOption = func(*openAiProviderOptions) +type Option = func(*options) -func NewOpenAiProvider(opts ...OpenAiOption) ai.Provider { - options := openAiProviderOptions{ +func New(opts ...Option) ai.Provider { + options := options{ headers: map[string]string{}, } for _, o := range opts { @@ -86,55 +86,55 @@ func NewOpenAiProvider(opts ...OpenAiOption) ai.Provider { options.headers["OpenAi-Project"] = options.project } - return &openAiProvider{ + return &provider{ options: options, } } -func WithOpenAiBaseURL(baseURL string) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithBaseURL(baseURL string) Option { + return func(o *options) { o.baseURL = baseURL } } -func WithOpenAiAPIKey(apiKey string) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithAPIKey(apiKey string) Option { + return func(o *options) { o.apiKey = apiKey } } -func WithOpenAiOrganization(organization string) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithOrganization(organization string) Option { + return func(o *options) { o.organization = organization } } -func WithOpenAiProject(project string) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithProject(project string) Option { + return func(o *options) { o.project = project } } -func WithOpenAiName(name string) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithName(name string) Option { + return func(o *options) { o.name = name } } -func WithOpenAiHeaders(headers map[string]string) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithHeaders(headers map[string]string) Option { + return func(o *options) { maps.Copy(o.headers, headers) } } -func WithOpenAiHTTPClient(client option.HTTPClient) OpenAiOption { - return func(o *openAiProviderOptions) { +func WithHTTPClient(client option.HTTPClient) Option { + return func(o *options) { o.client = client } } // LanguageModel implements ai.Provider. -func (o *openAiProvider) LanguageModel(modelID string) (ai.LanguageModel, error) { +func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) { openaiClientOptions := []option.RequestOption{} if o.options.apiKey != "" { openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey)) @@ -151,35 +151,35 @@ func (o *openAiProvider) LanguageModel(modelID string) (ai.LanguageModel, error) openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client)) } - return openAiLanguageModel{ - modelID: modelID, - provider: fmt.Sprintf("%s.chat", o.options.name), - providerOptions: o.options, - client: openai.NewClient(openaiClientOptions...), + return languageModel{ + modelID: modelID, + provider: fmt.Sprintf("%s.chat", o.options.name), + options: o.options, + client: openai.NewClient(openaiClientOptions...), }, nil } -type openAiLanguageModel struct { - provider string - modelID string - client openai.Client - providerOptions openAiProviderOptions +type languageModel struct { + provider string + modelID string + client openai.Client + options options } // Model implements ai.LanguageModel. -func (o openAiLanguageModel) Model() string { +func (o languageModel) Model() string { return o.modelID } // Provider implements ai.LanguageModel. -func (o openAiLanguageModel) Provider() string { +func (o languageModel) Provider() string { return o.provider } -func (o openAiLanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) { +func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) { params := &openai.ChatCompletionNewParams{} - messages, warnings := toOpenAiPrompt(call.Prompt) - providerOptions := &OpenAiProviderOptions{} + messages, warnings := toPrompt(call.Prompt) + providerOptions := &ProviderOptions{} if v, ok := call.ProviderOptions["openai"]; ok { err := ai.ParseOptions(v, providerOptions) if err != nil { @@ -398,7 +398,7 @@ func (o openAiLanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletion return params, warnings, nil } -func (o openAiLanguageModel) handleError(err error) error { +func (o languageModel) handleError(err error) error { var apiErr *openai.Error if errors.As(err, &apiErr) { requestDump := apiErr.DumpRequest(true) @@ -423,7 +423,7 @@ func (o openAiLanguageModel) handleError(err error) error { } // Generate implements ai.LanguageModel. -func (o openAiLanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { +func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { params, warnings, err := o.prepareParams(call) if err != nil { return nil, err @@ -515,7 +515,7 @@ type toolCall struct { } // Stream implements ai.LanguageModel. -func (o openAiLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { +func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { params, warnings, err := o.prepareParams(call) if err != nil { return nil, err @@ -865,7 +865,7 @@ func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []op return openAiTools, openAiToolChoice, warnings } -func toOpenAiPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) { +func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) { var messages []openai.ChatCompletionMessageParamUnion var warnings []ai.CallWarning for _, msg := range prompt { diff --git a/providers/openai_test.go b/providers/openai/openai_test.go similarity index 92% rename from providers/openai_test.go rename to providers/openai/openai_test.go index f5c0bc5240a74f0d66428fe4ea0f6e82c8b72f86..a0c9ae9e9af178f336166337439eeea974a99caf 100644 --- a/providers/openai_test.go +++ b/providers/openai/openai_test.go @@ -1,4 +1,4 @@ -package providers +package openai import ( "context" @@ -30,7 +30,7 @@ func TestToOpenAiPrompt_SystemMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -50,7 +50,7 @@ func TestToOpenAiPrompt_SystemMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Len(t, warnings, 1) require.Contains(t, warnings[0].Message, "system prompt has no text parts") @@ -70,7 +70,7 @@ func TestToOpenAiPrompt_SystemMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -96,7 +96,7 @@ func TestToOpenAiPrompt_UserMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -123,7 +123,7 @@ func TestToOpenAiPrompt_UserMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -167,7 +167,7 @@ func TestToOpenAiPrompt_UserMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -202,7 +202,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Len(t, warnings, 1) require.Contains(t, warnings[0].Message, "file part media type application/something not supported") @@ -225,7 +225,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -258,7 +258,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -286,7 +286,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -315,7 +315,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -349,7 +349,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -378,7 +378,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -408,7 +408,7 @@ func TestToOpenAiPrompt_FileParts(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -457,7 +457,7 @@ func TestToOpenAiPrompt_ToolCalls(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 2) @@ -504,7 +504,7 @@ func TestToOpenAiPrompt_ToolCalls(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 2) @@ -538,7 +538,7 @@ func TestToOpenAiPrompt_AssistantMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -568,7 +568,7 @@ func TestToOpenAiPrompt_AssistantMessages(t *testing.T) { }, } - messages, warnings := toOpenAiPrompt(prompt) + messages, warnings := toPrompt(prompt) require.Empty(t, warnings) require.Len(t, messages, 1) @@ -811,9 +811,9 @@ func TestDoGenerate(t *testing.T) { "content": "Hello, World!", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -843,9 +843,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -867,9 +867,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -907,9 +907,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -933,9 +933,9 @@ func TestDoGenerate(t *testing.T) { "logprobs": testLogprobs, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -969,9 +969,9 @@ func TestDoGenerate(t *testing.T) { "finish_reason": "stop", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -993,9 +993,9 @@ func TestDoGenerate(t *testing.T) { "finish_reason": "eos", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1017,9 +1017,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1049,9 +1049,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1093,9 +1093,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-mini") @@ -1133,9 +1133,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o") @@ -1173,9 +1173,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1245,9 +1245,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1303,9 +1303,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1345,9 +1345,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o-mini") @@ -1380,9 +1380,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o-mini") @@ -1407,9 +1407,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-preview") @@ -1455,9 +1455,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-preview") @@ -1499,9 +1499,9 @@ func TestDoGenerate(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-preview") @@ -1526,9 +1526,9 @@ func TestDoGenerate(t *testing.T) { "model": "o1-preview", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-preview") @@ -1566,9 +1566,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1612,9 +1612,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1652,9 +1652,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1696,9 +1696,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1736,9 +1736,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -1774,9 +1774,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o-search-preview") @@ -1808,9 +1808,9 @@ func TestDoGenerate(t *testing.T) { "content": "", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o3-mini") @@ -1846,9 +1846,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o-mini") @@ -1881,9 +1881,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o-mini") @@ -1919,9 +1919,9 @@ func TestDoGenerate(t *testing.T) { server.prepareJSONResponse(map[string]any{}) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2228,9 +2228,9 @@ func TestDoStream(t *testing.T) { "logprobs": testLogprobs, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2284,9 +2284,9 @@ func TestDoStream(t *testing.T) { server.prepareToolStreamResponse() - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2370,9 +2370,9 @@ func TestDoStream(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2409,9 +2409,9 @@ func TestDoStream(t *testing.T) { server.prepareErrorStreamResponse() - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2450,9 +2450,9 @@ func TestDoStream(t *testing.T) { "content": []string{}, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2498,9 +2498,9 @@ func TestDoStream(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2548,9 +2548,9 @@ func TestDoStream(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2591,9 +2591,9 @@ func TestDoStream(t *testing.T) { "content": []string{}, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2635,9 +2635,9 @@ func TestDoStream(t *testing.T) { "content": []string{}, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-3.5-turbo") @@ -2683,9 +2683,9 @@ func TestDoStream(t *testing.T) { "content": []string{}, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o3-mini") @@ -2727,9 +2727,9 @@ func TestDoStream(t *testing.T) { "content": []string{}, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("gpt-4o-mini") @@ -2772,9 +2772,9 @@ func TestDoStream(t *testing.T) { "model": "o1-preview", }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-preview") @@ -2818,9 +2818,9 @@ func TestDoStream(t *testing.T) { }, }) - provider := NewOpenAiProvider( - WithOpenAiAPIKey("test-api-key"), - WithOpenAiBaseURL(server.server.URL), + provider := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.server.URL), ) model, _ := provider.LanguageModel("o1-preview")