@@ -0,0 +1,1161 @@
+package providers
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "maps"
+ "strings"
+
+ "github.com/charmbracelet/crush/internal/ai"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/env"
+ "github.com/google/uuid"
+ "github.com/openai/openai-go/v2"
+ "github.com/openai/openai-go/v2/option"
+ "github.com/openai/openai-go/v2/packages/param"
+ "github.com/openai/openai-go/v2/shared"
+)
+
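+// ReasoningEffort controls how much effort OpenAI reasoning models spend
+// thinking before they answer.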
+type ReasoningEffort string
+
+const (
+ ReasoningEffortMinimal ReasoningEffort = "minimal"
+ ReasoningEffortLow ReasoningEffort = "low"
+ ReasoningEffortMedium ReasoningEffort = "medium"
+ ReasoningEffortHigh ReasoningEffort = "high"
+)
+
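+// OpenAIProviderOptions are the OpenAI-specific call options, passed under
+// the "openai" key of ai.Call.ProviderOptions.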
+type OpenAIProviderOptions struct {
+ LogitBias map[string]int64 `json:"logit_bias"`
+ LogProbs *bool `json:"log_probs"`
+ TopLogProbs *int64 `json:"top_log_probs"`
+ ParallelToolCalls *bool `json:"parallel_tool_calls"`
+ User *string `json:"user"`
+ ReasoningEffort *ReasoningEffort `json:"reasoning_effort"`
+ MaxCompletionTokens *int64 `json:"max_completion_tokens"`
+ TextVerbosity *string `json:"text_verbosity"`
+ Prediction map[string]any `json:"prediction"`
+ Store *bool `json:"store"`
+ Metadata map[string]any `json:"metadata"`
+ PromptCacheKey *string `json:"prompt_cache_key"`
+ SafetyIdentifier *string `json:"safety_identifier"`
+ ServiceTier *string `json:"service_tier"`
+ StructuredOutputs *bool `json:"structured_outputs"`
+}
+
+type openAIProvider struct {
+ options openAIProviderOptions
+}
+
+type openAIProviderOptions struct {
+ baseURL string
+ apiKey string
+ organization string
+ project string
+ name string
+ headers map[string]string
+ client option.HTTPClient
+ resolver config.VariableResolver
+}
+
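+// OpenAIOption configures the provider built by NewOpenAIProvider.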
+type OpenAIOption = func(*openAIProviderOptions)
+
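+// NewOpenAIProvider creates an ai.Provider backed by the OpenAI chat
+// completions API. Option values are expanded with the configured variable
+// resolver, so they may reference environment variables. A minimal sketch
+// (the model ID is illustrative):
+//
+//	provider := NewOpenAIProvider(WithOpenAIApiKey("$OPENAI_API_KEY"))
+//	model := provider.LanguageModel("gpt-4o")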
+func NewOpenAIProvider(opts ...OpenAIOption) ai.Provider {
+ options := openAIProviderOptions{
+ headers: map[string]string{},
+ }
+ for _, o := range opts {
+ o(&options)
+ }
+
+ if options.resolver == nil {
+ // use the default resolver
+ options.resolver = config.NewShellVariableResolver(env.New())
+ }
+ options.apiKey, _ = options.resolver.ResolveValue(options.apiKey)
+ options.baseURL, _ = options.resolver.ResolveValue(options.baseURL)
+ if options.baseURL == "" {
+ options.baseURL = "https://api.openai.com/v1"
+ }
+
+ options.name, _ = options.resolver.ResolveValue(options.name)
+ if options.name == "" {
+ options.name = "openai"
+ }
+
+ for k, v := range options.headers {
+ options.headers[k], _ = options.resolver.ResolveValue(v)
+ }
+
+ options.organization, _ = options.resolver.ResolveValue(options.organization)
+ if options.organization != "" {
+ options.headers["OpenAI-Organization"] = options.organization
+ }
+
+ options.project, _ = options.resolver.ResolveValue(options.project)
+ if options.project != "" {
+ options.headers["OpenAI-Project"] = options.project
+ }
+
+ return &openAIProvider{
+ options: options,
+ }
+}
+
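+// WithOpenAIBaseURL overrides the default base URL (https://api.openai.com/v1).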
+func WithOpenAIBaseURL(baseURL string) OpenAIOption {
+ return func(o *openAIProviderOptions) {
+ o.baseURL = baseURL
+ }
+}
+
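+// WithOpenAIApiKey sets the API key used to authenticate requests.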
+func WithOpenAIApiKey(apiKey string) OpenAIOption {
+ return func(o *openAIProviderOptions) {
+ o.apiKey = apiKey
+ }
+}
+
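+// WithOpenAIOrganization sets the OpenAI-Organization header.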
+func WithOpenAIOrganization(organization string) OpenAIOption {
+ return func(o *openAIProviderOptions) {
+ o.organization = organization
+ }
+}
+
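+// WithOpenAIProject sets the OpenAI-Project header.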
+func WithOpenAIProject(project string) OpenAIOption {
+ return func(o *openAIProviderOptions) {
+ o.project = project
+ }
+}
+
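+// WithOpenAIName overrides the provider name reported by models (defaults to "openai").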
+func WithOpenAIName(name string) OpenAIOption {
+ return func(o *openAIProviderOptions) {
+ o.name = name
+ }
+}
+
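+// WithOpenAIHeaders adds extra headers to every request.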
+func WithOpenAIHeaders(headers map[string]string) OpenAIOption {
+ return func(o *openAIProviderOptions) {
+ maps.Copy(o.headers, headers)
+ }
+}
+
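+// WithOpenAIHttpClient sets a custom HTTP client.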
+func WithOpenAIHttpClient(client option.HTTPClient) OpenAIOption {
+ return func(o *openAIProviderOptions) {
+ o.client = client
+ }
+}
+
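+// WithOpenAIVariableResolver sets the resolver used to expand option values.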
+func WithOpenAIVariableResolver(resolver config.VariableResolver) OpenAIOption {
+ return func(o *openAIProviderOptions) {
+ o.resolver = resolver
+ }
+}
+
+// LanguageModel implements ai.Provider.
+func (o *openAIProvider) LanguageModel(modelID string) ai.LanguageModel {
+ openaiClientOptions := []option.RequestOption{}
+ if o.options.apiKey != "" {
+ openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
+ }
+ if o.options.baseURL != "" {
+ openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
+ }
+
+ for key, value := range o.options.headers {
+ openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
+ }
+
+ if o.options.client != nil {
+ openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
+ }
+
+ return openAILanguageModel{
+ modelID: modelID,
+ provider: fmt.Sprintf("%s.chat", o.options.name),
+ providerOptions: o.options,
+ client: openai.NewClient(openaiClientOptions...),
+ }
+}
+
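+// openAILanguageModel implements ai.LanguageModel on top of the OpenAI chat
+// completions API.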
+type openAILanguageModel struct {
+ provider string
+ modelID string
+ client openai.Client
+ providerOptions openAIProviderOptions
+}
+
+// Model implements ai.LanguageModel.
+func (o openAILanguageModel) Model() string {
+ return o.modelID
+}
+
+// Provider implements ai.LanguageModel.
+func (o openAILanguageModel) Provider() string {
+ return o.provider
+}
+
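+// prepareParams converts an ai.Call into OpenAI chat completion parameters,
+// collecting warnings for settings the target model does not support.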
+func (o openAILanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
+ params := &openai.ChatCompletionNewParams{}
+ messages, warnings := toOpenAIPrompt(call.Prompt)
+ providerOptions := &OpenAIProviderOptions{}
+ if v, ok := call.ProviderOptions["openai"]; ok {
+ err := ai.ParseOptions(v, providerOptions)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ if call.TopK != nil {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "top_k",
+ })
+ }
+ params.Messages = messages
+ params.Model = o.modelID
+ if providerOptions.LogitBias != nil {
+ params.LogitBias = providerOptions.LogitBias
+ }
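+ // When both logprobs and top_logprobs are set, only top_logprobs is sent.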
+ if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
+ providerOptions.LogProbs = nil
+ }
+ if providerOptions.LogProbs != nil {
+ params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
+ }
+ if providerOptions.TopLogProbs != nil {
+ params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
+ }
+ if providerOptions.User != nil {
+ params.User = param.NewOpt(*providerOptions.User)
+ }
+ if providerOptions.ParallelToolCalls != nil {
+ params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
+ }
+
+ if call.MaxOutputTokens != nil {
+ params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
+ }
+ if call.Temperature != nil {
+ params.Temperature = param.NewOpt(*call.Temperature)
+ }
+ if call.TopP != nil {
+ params.TopP = param.NewOpt(*call.TopP)
+ }
+ if call.FrequencyPenalty != nil {
+ params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
+ }
+ if call.PresencePenalty != nil {
+ params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
+ }
+
+ if providerOptions.MaxCompletionTokens != nil {
+ params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
+ }
+
+ if providerOptions.TextVerbosity != nil {
+ params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
+ }
+ if providerOptions.Prediction != nil {
+ // Convert map[string]any to ChatCompletionPredictionContentParam
+ if content, ok := providerOptions.Prediction["content"]; ok {
+ if contentStr, ok := content.(string); ok {
+ params.Prediction = openai.ChatCompletionPredictionContentParam{
+ Content: openai.ChatCompletionPredictionContentContentUnionParam{
+ OfString: param.NewOpt(contentStr),
+ },
+ }
+ }
+ }
+ }
+ if providerOptions.Store != nil {
+ params.Store = param.NewOpt(*providerOptions.Store)
+ }
+ if providerOptions.Metadata != nil {
+ // Convert map[string]any to map[string]string
+ metadata := make(map[string]string)
+ for k, v := range providerOptions.Metadata {
+ if str, ok := v.(string); ok {
+ metadata[k] = str
+ }
+ }
+ params.Metadata = metadata
+ }
+ if providerOptions.PromptCacheKey != nil {
+ params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
+ }
+ if providerOptions.SafetyIdentifier != nil {
+ params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
+ }
+ if providerOptions.ServiceTier != nil {
+ params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
+ }
+
+ if providerOptions.ReasoningEffort != nil {
+ switch *providerOptions.ReasoningEffort {
+ case ReasoningEffortMinimal:
+ params.ReasoningEffort = shared.ReasoningEffortMinimal
+ case ReasoningEffortLow:
+ params.ReasoningEffort = shared.ReasoningEffortLow
+ case ReasoningEffortMedium:
+ params.ReasoningEffort = shared.ReasoningEffortMedium
+ case ReasoningEffortHigh:
+ params.ReasoningEffort = shared.ReasoningEffortHigh
+ default:
+ return nil, nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
+ }
+ }
+
+ if isReasoningModel(o.modelID) {
+ // remove unsupported settings for reasoning models
+ // see https://platform.openai.com/docs/guides/reasoning#limitations
+ if call.Temperature != nil {
+ params.Temperature = param.Opt[float64]{}
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "temperature",
+ Details: "temperature is not supported for reasoning models",
+ })
+ }
+ if call.TopP != nil {
+ params.TopP = param.Opt[float64]{}
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "top_p",
+ Details: "topP is not supported for reasoning models",
+ })
+ }
+ if call.FrequencyPenalty != nil {
+ params.FrequencyPenalty = param.Opt[float64]{}
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "frequency_penalty",
+ Details: "frequencyPenalty is not supported for reasoning models",
+ })
+ }
+ if call.PresencePenalty != nil {
+ params.PresencePenalty = param.Opt[float64]{}
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "presence_penalty",
+ Details: "presencePenalty is not supported for reasoning models",
+ })
+ }
+ if providerOptions.LogitBias != nil {
+ params.LogitBias = nil
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "logitBias is not supported for reasoning models",
+ })
+ }
+ if providerOptions.LogProbs != nil {
+ params.Logprobs = param.Opt[bool]{}
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "logprobs is not supported for reasoning models",
+ })
+ }
+ if providerOptions.TopLogProbs != nil {
+ params.TopLogprobs = param.Opt[int64]{}
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "topLogprobs is not supported for reasoning models",
+ })
+ }
+
+ // reasoning models use max_completion_tokens instead of max_tokens
+ if call.MaxOutputTokens != nil {
+ if providerOptions.MaxCompletionTokens == nil {
+ params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
+ }
+ params.MaxTokens = param.Opt[int64]{}
+ }
+ }
+
+ // Handle search preview models
+ if isSearchPreviewModel(o.modelID) {
+ if call.Temperature != nil {
+ params.Temperature = param.Opt[float64]{}
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "temperature",
+ Details: "temperature is not supported for the search preview models and has been removed.",
+ })
+ }
+ }
+
+ // Handle service tier validation
+ if providerOptions.ServiceTier != nil {
+ serviceTier := *providerOptions.ServiceTier
+ if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
+ params.ServiceTier = ""
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "serviceTier",
+ Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
+ })
+ } else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
+ params.ServiceTier = ""
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedSetting,
+ Setting: "serviceTier",
+ Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
+ })
+ }
+ }
+
+ if len(call.Tools) > 0 {
+ tools, toolChoice, toolWarnings := toOpenAITools(call.Tools, call.ToolChoice)
+ params.Tools = tools
+ if toolChoice != nil {
+ params.ToolChoice = *toolChoice
+ }
+ warnings = append(warnings, toolWarnings...)
+ }
+ return params, warnings, nil
+}
+
+// Generate implements ai.LanguageModel.
+func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
+ params, warnings, err := o.prepareParams(call)
+ if err != nil {
+ return nil, err
+ }
+ response, err := o.client.Chat.Completions.New(ctx, *params)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(response.Choices) == 0 {
+ return nil, errors.New("no response generated")
+ }
+ choice := response.Choices[0]
+ var content []ai.Content
+ text := choice.Message.Content
+ if text != "" {
+ content = append(content, ai.TextContent{
+ Text: text,
+ })
+ }
+
+ for _, tc := range choice.Message.ToolCalls {
+ toolCallID := tc.ID
+ if toolCallID == "" {
+ toolCallID = uuid.NewString()
+ }
+ content = append(content, ai.ToolCallContent{
+ ProviderExecuted: false, // TODO: update when handling other tools
+ ToolCallID: toolCallID,
+ ToolName: tc.Function.Name,
+ Input: tc.Function.Arguments,
+ })
+ }
+ // Handle annotations/citations
+ for _, annotation := range choice.Message.Annotations {
+ if annotation.Type == "url_citation" {
+ content = append(content, ai.SourceContent{
+ SourceType: ai.SourceTypeURL,
+ ID: uuid.NewString(),
+ URL: annotation.URLCitation.URL,
+ Title: annotation.URLCitation.Title,
+ })
+ }
+ }
+
+ completionTokenDetails := response.Usage.CompletionTokensDetails
+ promptTokenDetails := response.Usage.PromptTokensDetails
+
+ // Build provider metadata
+ providerMetadata := ai.ProviderMetadata{
+ "openai": make(map[string]any),
+ }
+
+ // Add logprobs if available
+ if len(choice.Logprobs.Content) > 0 {
+ providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
+ }
+
+ // Add prediction tokens if available
+ if completionTokenDetails.AcceptedPredictionTokens > 0 {
+ providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
+ }
+ if completionTokenDetails.RejectedPredictionTokens > 0 {
+ providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
+ }
+
+ return &ai.Response{
+ Content: content,
+ Usage: ai.Usage{
+ InputTokens: response.Usage.PromptTokens,
+ OutputTokens: response.Usage.CompletionTokens,
+ TotalTokens: response.Usage.TotalTokens,
+ ReasoningTokens: completionTokenDetails.ReasoningTokens,
+ CacheReadTokens: promptTokenDetails.CachedTokens,
+ },
+ FinishReason: mapOpenAIFinishReason(choice.FinishReason),
+ ProviderMetadata: providerMetadata,
+ Warnings: warnings,
+ }, nil
+}
+
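+// toolCall tracks the accumulated state of a single streamed tool call.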
+type toolCall struct {
+ id string
+ name string
+ arguments string
+ hasFinished bool
+}
+
+// Stream implements ai.LanguageModel.
+func (o openAILanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
+ params, warnings, err := o.prepareParams(call)
+ if err != nil {
+ return nil, err
+ }
+
+ params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
+ IncludeUsage: openai.Bool(true),
+ }
+
+ stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
+ isActiveText := false
+ toolCalls := make(map[int64]toolCall)
+
+ // Build provider metadata for streaming
+ streamProviderMetadata := ai.ProviderOptions{
+ "openai": make(map[string]any),
+ }
+
+ acc := openai.ChatCompletionAccumulator{}
+ var usage ai.Usage
+ return func(yield func(ai.StreamPart) bool) {
+ if len(warnings) > 0 {
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeWarnings,
+ Warnings: warnings,
+ }) {
+ return
+ }
+ }
+ for stream.Next() {
+ chunk := stream.Current()
+ acc.AddChunk(chunk)
+ if chunk.Usage.TotalTokens > 0 {
+ // compute usage from the chunk itself because the accumulator does not retain prompt token details
+ completionTokenDetails := chunk.Usage.CompletionTokensDetails
+ promptTokenDetails := chunk.Usage.PromptTokensDetails
+ usage = ai.Usage{
+ InputTokens: chunk.Usage.PromptTokens,
+ OutputTokens: chunk.Usage.CompletionTokens,
+ TotalTokens: chunk.Usage.TotalTokens,
+ ReasoningTokens: completionTokenDetails.ReasoningTokens,
+ CacheReadTokens: promptTokenDetails.CachedTokens,
+ }
+
+ // Add prediction tokens if available
+ if completionTokenDetails.AcceptedPredictionTokens > 0 {
+ streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
+ }
+ if completionTokenDetails.RejectedPredictionTokens > 0 {
+ streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
+ }
+ }
+ if len(chunk.Choices) == 0 {
+ continue
+ }
+ for _, choice := range chunk.Choices {
+ switch {
+ case choice.Delta.Content != "":
+ if !isActiveText {
+ isActiveText = true
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeTextStart,
+ ID: "0",
+ }) {
+ return
+ }
+ }
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeTextDelta,
+ ID: "0",
+ Delta: choice.Delta.Content,
+ }) {
+ return
+ }
+ case len(choice.Delta.ToolCalls) > 0:
+ if isActiveText {
+ isActiveText = false
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeTextEnd,
+ ID: "0",
+ }) {
+ return
+ }
+ }
+
+ for _, toolCallDelta := range choice.Delta.ToolCalls {
+ if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
+ if existingToolCall.hasFinished {
+ continue
+ }
+ if toolCallDelta.Function.Arguments != "" {
+ existingToolCall.arguments += toolCallDelta.Function.Arguments
+ }
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeToolInputDelta,
+ ID: existingToolCall.id,
+ Delta: toolCallDelta.Function.Arguments,
+ }) {
+ return
+ }
+ toolCalls[toolCallDelta.Index] = existingToolCall
+ if existingToolCall.arguments != "" && ai.IsParsableJSON(existingToolCall.arguments) {
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeToolInputEnd,
+ ID: existingToolCall.id,
+ }) {
+ return
+ }
+
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeToolCall,
+ ID: existingToolCall.id,
+ ToolCallName: existingToolCall.name,
+ ToolCallInput: existingToolCall.arguments,
+ }) {
+ return
+ }
+ existingToolCall.hasFinished = true
+ toolCalls[toolCallDelta.Index] = existingToolCall
+ }
+
+ } else {
+ // First delta for this tool call: validate it and start tracking.
+ var err error
+ if toolCallDelta.Type != "function" {
+ err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
+ }
+ if toolCallDelta.ID == "" {
+ err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
+ }
+ if toolCallDelta.Function.Name == "" {
+ err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
+ }
+ if err != nil {
+ yield(ai.StreamPart{
+ Type: ai.StreamPartTypeError,
+ Error: err,
+ })
+ return
+ }
+
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeToolInputStart,
+ ID: toolCallDelta.ID,
+ ToolCallName: toolCallDelta.Function.Name,
+ }) {
+ return
+ }
+ toolCalls[toolCallDelta.Index] = toolCall{
+ id: toolCallDelta.ID,
+ name: toolCallDelta.Function.Name,
+ arguments: toolCallDelta.Function.Arguments,
+ }
+
+ exTc := toolCalls[toolCallDelta.Index]
+ if exTc.arguments != "" {
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeToolInputDelta,
+ ID: exTc.id,
+ Delta: exTc.arguments,
+ }) {
+ return
+ }
+ if ai.IsParsableJSON(toolCalls[toolCallDelta.Index].arguments) {
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeToolInputEnd,
+ ID: toolCallDelta.ID,
+ }) {
+ return
+ }
+
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeToolCall,
+ ID: exTc.id,
+ ToolCallName: exTc.name,
+ ToolCallInput: exTc.arguments,
+ }) {
+ return
+ }
+ exTc.hasFinished = true
+ toolCalls[toolCallDelta.Index] = exTc
+ }
+ }
+ continue
+ }
+ }
+ }
+ }
+
+ // Check for annotations in the delta's raw JSON
+ for _, choice := range chunk.Choices {
+ if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
+ for _, annotation := range annotations {
+ if annotation.Type == "url_citation" {
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeSource,
+ ID: uuid.NewString(),
+ SourceType: ai.SourceTypeURL,
+ URL: annotation.URLCitation.URL,
+ Title: annotation.URLCitation.Title,
+ }) {
+ return
+ }
+ }
+ }
+ }
+ }
+
+ }
+ err := stream.Err()
+ if err == nil || errors.Is(err, io.EOF) {
+ // finished
+ if isActiveText {
+ isActiveText = false
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeTextEnd,
+ ID: "0",
+ }) {
+ return
+ }
+ }
+
+ // Add logprobs if available
+ if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
+ streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
+ }
+
+ // Handle annotations/citations from accumulated response
+ if len(acc.Choices) > 0 {
+ for _, annotation := range acc.Choices[0].Message.Annotations {
+ if annotation.Type == "url_citation" {
+ if !yield(ai.StreamPart{
+ Type: ai.StreamPartTypeSource,
+ ID: uuid.NewString(),
+ SourceType: ai.SourceTypeURL,
+ URL: annotation.URLCitation.URL,
+ Title: annotation.URLCitation.Title,
+ }) {
+ return
+ }
+ }
+ }
+ }
+
+ finishReason := ai.FinishReasonUnknown
+ if len(acc.Choices) > 0 {
+ finishReason = mapOpenAIFinishReason(acc.Choices[0].FinishReason)
+ }
+ yield(ai.StreamPart{
+ Type: ai.StreamPartTypeFinish,
+ Usage: usage,
+ FinishReason: finishReason,
+ ProviderMetadata: streamProviderMetadata,
+ })
+ return
+
+ } else {
+ yield(ai.StreamPart{
+ Type: ai.StreamPartTypeError,
+ Error: err,
+ })
+ return
+ }
+ }, nil
+}
+
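+// mapOpenAIFinishReason converts an OpenAI finish reason into the ai package equivalent.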
+func mapOpenAIFinishReason(finishReason string) ai.FinishReason {
+ switch finishReason {
+ case "stop":
+ return ai.FinishReasonStop
+ case "length":
+ return ai.FinishReasonLength
+ case "content_filter":
+ return ai.FinishReasonContentFilter
+ case "function_call", "tool_calls":
+ return ai.FinishReasonToolCalls
+ default:
+ return ai.FinishReasonUnknown
+ }
+}
+
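+// isReasoningModel reports whether the model is an OpenAI reasoning model
+// (o-series and gpt-5, excluding the conversational gpt-5-chat models).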
+func isReasoningModel(modelID string) bool {
+ return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5") || strings.HasPrefix(modelID, "gpt-5-chat")
+}
+
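+// isSearchPreviewModel reports whether the model is a search-preview variant.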
+func isSearchPreviewModel(modelID string) bool {
+ return strings.Contains(modelID, "search-preview")
+}
+
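+// supportsFlexProcessing reports whether the model accepts the "flex" service tier.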
+func supportsFlexProcessing(modelID string) bool {
+ return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
+}
+
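+// supportsPriorityProcessing reports whether the model accepts the "priority" service tier.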
+func supportsPriorityProcessing(modelID string) bool {
+ return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
+ strings.HasPrefix(modelID, "gpt-5-mini") || strings.HasPrefix(modelID, "o3") ||
+ strings.HasPrefix(modelID, "o4-mini")
+}
+
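+// toOpenAITools converts generic tool definitions and the tool choice into
+// their OpenAI request equivalents, warning on unsupported tool types.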
+func toOpenAITools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAITools []openai.ChatCompletionToolUnionParam, openAIToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
+ for _, tool := range tools {
+ if tool.GetType() == ai.ToolTypeFunction {
+ ft, ok := tool.(ai.FunctionTool)
+ if !ok {
+ continue
+ }
+ openAITools = append(openAITools, openai.ChatCompletionToolUnionParam{
+ OfFunction: &openai.ChatCompletionFunctionToolParam{
+ Function: shared.FunctionDefinitionParam{
+ Name: ft.Name,
+ Description: param.NewOpt(ft.Description),
+ Parameters: openai.FunctionParameters(ft.InputSchema),
+ Strict: param.NewOpt(false),
+ },
+ Type: "function",
+ },
+ })
+ continue
+ }
+
+ // TODO: handle provider tool calls
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeUnsupportedTool,
+ Tool: tool,
+ Message: "tool is not supported",
+ })
+ }
+ if toolChoice == nil {
+ return
+ }
+
+ switch *toolChoice {
+ case ai.ToolChoiceAuto:
+ openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
+ OfAuto: param.NewOpt("auto"),
+ }
+ case ai.ToolChoiceNone:
+ openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
+ OfAuto: param.NewOpt("none"),
+ }
+ default:
+ openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
+ OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
+ Type: "function",
+ Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
+ Name: string(*toolChoice),
+ },
+ },
+ }
+ }
+ return
+}
+
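+// toOpenAIPrompt converts an ai.Prompt into OpenAI chat messages, emitting
+// warnings for content it cannot represent.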
+func toOpenAIPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
+ var messages []openai.ChatCompletionMessageParamUnion
+ var warnings []ai.CallWarning
+ for _, msg := range prompt {
+ switch msg.Role {
+ case ai.MessageRoleSystem:
+ var systemPromptParts []string
+ for _, c := range msg.Content {
+ if c.GetType() != ai.ContentTypeText {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "system prompt can only have text content",
+ })
+ continue
+ }
+ textPart, ok := ai.AsContentType[ai.TextPart](c)
+ if !ok {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "system prompt text part does not have the right type",
+ })
+ continue
+ }
+ text := textPart.Text
+ if strings.TrimSpace(text) != "" {
+ systemPromptParts = append(systemPromptParts, textPart.Text)
+ }
+ }
+ if len(systemPromptParts) == 0 {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "system prompt has no text parts",
+ })
+ continue
+ }
+ messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
+ case ai.MessageRoleUser:
+ // simple user message with only text content
+ if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
+ textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
+ if !ok {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "user message text part does not have the right type",
+ })
+ continue
+ }
+ messages = append(messages, openai.UserMessage(textPart.Text))
+ continue
+ }
+ // text content and attachments
+ // For now we only support image, audio, and PDF attachments.
+ // TODO: add the supported media types to the language model so we
+ // can use them to validate the data here.
+ var content []openai.ChatCompletionContentPartUnionParam
+ for _, c := range msg.Content {
+ switch c.GetType() {
+ case ai.ContentTypeText:
+ textPart, ok := ai.AsContentType[ai.TextPart](c)
+ if !ok {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "user message text part does not have the right type",
+ })
+ continue
+ }
+ content = append(content, openai.ChatCompletionContentPartUnionParam{
+ OfText: &openai.ChatCompletionContentPartTextParam{
+ Text: textPart.Text,
+ },
+ })
+ case ai.ContentTypeFile:
+ filePart, ok := ai.AsContentType[ai.FilePart](c)
+ if !ok {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "user message file part does not have the right type",
+ })
+ continue
+ }
+
+ switch {
+ case strings.HasPrefix(filePart.MediaType, "image/"):
+ // Handle image files
+ base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
+ data := "data:" + filePart.MediaType + ";base64," + base64Encoded
+ imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
+
+ // Check for provider-specific options like image detail
+ if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
+ if detail, ok := providerOptions["imageDetail"].(string); ok {
+ imageURL.Detail = detail
+ }
+ }
+
+ imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
+ content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
+
+ case filePart.MediaType == "audio/wav":
+ // Handle WAV audio files
+ base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
+ audioBlock := openai.ChatCompletionContentPartInputAudioParam{
+ InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
+ Data: base64Encoded,
+ Format: "wav",
+ },
+ }
+ content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
+
+ case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
+ // Handle MP3 audio files
+ base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
+ audioBlock := openai.ChatCompletionContentPartInputAudioParam{
+ InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
+ Data: base64Encoded,
+ Format: "mp3",
+ },
+ }
+ content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
+
+ case filePart.MediaType == "application/pdf":
+ // Handle PDF files
+ dataStr := string(filePart.Data)
+
+ // Check if data looks like a file ID (starts with "file-")
+ if strings.HasPrefix(dataStr, "file-") {
+ fileBlock := openai.ChatCompletionContentPartFileParam{
+ File: openai.ChatCompletionContentPartFileFileParam{
+ FileID: param.NewOpt(dataStr),
+ },
+ }
+ content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
+ } else {
+ // Handle as base64 data
+ base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
+ data := "data:application/pdf;base64," + base64Encoded
+
+ filename := filePart.Filename
+ if filename == "" {
+ // Generate default filename based on content index
+ filename = fmt.Sprintf("part-%d.pdf", len(content))
+ }
+
+ fileBlock := openai.ChatCompletionContentPartFileParam{
+ File: openai.ChatCompletionContentPartFileFileParam{
+ Filename: param.NewOpt(filename),
+ FileData: param.NewOpt(data),
+ },
+ }
+ content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
+ }
+
+ default:
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
+ })
+ }
+ }
+ }
+ messages = append(messages, openai.UserMessage(content))
+ case ai.MessageRoleAssistant:
+ // simple assistant message with only text content
+ if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
+ textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
+ if !ok {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "assistant message text part does not have the right type",
+ })
+ continue
+ }
+ messages = append(messages, openai.AssistantMessage(textPart.Text))
+ continue
+ }
+ assistantMsg := openai.ChatCompletionAssistantMessageParam{
+ Role: "assistant",
+ }
+ for _, c := range msg.Content {
+ switch c.GetType() {
+ case ai.ContentTypeText:
+ textPart, ok := ai.AsContentType[ai.TextPart](c)
+ if !ok {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "assistant message text part does not have the right type",
+ })
+ continue
+ }
+ assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
+ OfString: param.NewOpt(textPart.Text),
+ }
+ case ai.ContentTypeToolCall:
+ toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
+ if !ok {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "assistant message tool part does not have the right type",
+ })
+ continue
+ }
+ assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
+ openai.ChatCompletionMessageToolCallUnionParam{
+ OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
+ ID: toolCallPart.ToolCallID,
+ Type: "function",
+ Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
+ Name: toolCallPart.ToolName,
+ Arguments: toolCallPart.Input,
+ },
+ },
+ })
+ }
+ }
+ messages = append(messages, openai.ChatCompletionMessageParamUnion{
+ OfAssistant: &assistantMsg,
+ })
+ case ai.MessageRoleTool:
+ for _, c := range msg.Content {
+ if c.GetType() != ai.ContentTypeToolResult {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "tool message can only have tool result content",
+ })
+ continue
+ }
+
+ toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
+ if !ok {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "tool message result part does not have the right type",
+ })
+ continue
+ }
+
+ switch toolResultPart.Output.GetType() {
+ case ai.ToolResultContentTypeText:
+ output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
+ if !ok {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "tool result output does not have the right type",
+ })
+ continue
+ }
+ messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
+ case ai.ToolResultContentTypeError:
+ // TODO: check if better handling is needed
+ output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
+ if !ok {
+ warnings = append(warnings, ai.CallWarning{
+ Type: ai.CallWarningTypeOther,
+ Message: "tool result output does not have the right type",
+ })
+ continue
+ }
+ messages = append(messages, openai.ToolMessage(output.Error, toolResultPart.ToolCallID))
+ }
+ }
+ }
+ }
+ return messages, warnings
+}
+
+// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta
+func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
+ var annotations []openai.ChatCompletionMessageAnnotation
+
+ // Parse the raw JSON to extract annotations
+ var deltaData map[string]any
+ if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
+ return annotations
+ }
+
+ // Check if annotations exist in the delta
+ if annotationsData, ok := deltaData["annotations"].([]any); ok {
+ for _, annotationData := range annotationsData {
+ if annotationMap, ok := annotationData.(map[string]any); ok {
+ if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
+ if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
+ // use safe assertions so a malformed annotation cannot panic
+ url, _ := urlCitationData["url"].(string)
+ title, _ := urlCitationData["title"].(string)
+ annotation := openai.ChatCompletionMessageAnnotation{
+ Type: "url_citation",
+ URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
+ URL: url,
+ Title: title,
+ },
+ }
+ annotations = append(annotations, annotation)
+ }
+ }
+ }
+ }
+ }
+
+ return annotations
+}
@@ -0,0 +1,2850 @@
+package providers
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/ai"
+ "github.com/openai/openai-go/v2/packages/param"
+ "github.com/stretchr/testify/require"
+)
+
+func TestToOpenAIPrompt_SystemMessages(t *testing.T) {
+ t.Parallel()
+
+ t.Run("should forward system messages", func(t *testing.T) {
+ t.Parallel()
+
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleSystem,
+ Content: []ai.MessagePart{
+ ai.TextPart{Text: "You are a helpful assistant."},
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ systemMsg := messages[0].OfSystem
+ require.NotNil(t, systemMsg)
+ require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
+ })
+
+ t.Run("should handle empty system messages", func(t *testing.T) {
+ t.Parallel()
+
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleSystem,
+ Content: []ai.MessagePart{},
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Len(t, warnings, 1)
+ require.Contains(t, warnings[0].Message, "system prompt has no text parts")
+ require.Empty(t, messages)
+ })
+
+ t.Run("should join multiple system text parts", func(t *testing.T) {
+ t.Parallel()
+
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleSystem,
+ Content: []ai.MessagePart{
+ ai.TextPart{Text: "You are a helpful assistant."},
+ ai.TextPart{Text: "Be concise."},
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ systemMsg := messages[0].OfSystem
+ require.NotNil(t, systemMsg)
+ require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
+ })
+}
+
+func TestToOpenAIPrompt_UserMessages(t *testing.T) {
+ t.Parallel()
+
+ t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
+ t.Parallel()
+
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleUser,
+ Content: []ai.MessagePart{
+ ai.TextPart{Text: "Hello"},
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ userMsg := messages[0].OfUser
+ require.NotNil(t, userMsg)
+ require.Equal(t, "Hello", userMsg.Content.OfString.Value)
+ })
+
+ t.Run("should convert messages with image parts", func(t *testing.T) {
+ t.Parallel()
+
+ imageData := []byte{0, 1, 2, 3}
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleUser,
+ Content: []ai.MessagePart{
+ ai.TextPart{Text: "Hello"},
+ ai.FilePart{
+ MediaType: "image/png",
+ Data: imageData,
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ userMsg := messages[0].OfUser
+ require.NotNil(t, userMsg)
+
+ content := userMsg.Content.OfArrayOfContentParts
+ require.Len(t, content, 2)
+
+ // Check text part
+ textPart := content[0].OfText
+ require.NotNil(t, textPart)
+ require.Equal(t, "Hello", textPart.Text)
+
+ // Check image part
+ imagePart := content[1].OfImageURL
+ require.NotNil(t, imagePart)
+ expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
+ require.Equal(t, expectedURL, imagePart.ImageURL.URL)
+ })
+
+ t.Run("should add image detail when specified through provider options", func(t *testing.T) {
+ t.Parallel()
+
+ imageData := []byte{0, 1, 2, 3}
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleUser,
+ Content: []ai.MessagePart{
+ ai.FilePart{
+ MediaType: "image/png",
+ Data: imageData,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "imageDetail": "low",
+ },
+ },
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ userMsg := messages[0].OfUser
+ require.NotNil(t, userMsg)
+
+ content := userMsg.Content.OfArrayOfContentParts
+ require.Len(t, content, 1)
+
+ imagePart := content[0].OfImageURL
+ require.NotNil(t, imagePart)
+ require.Equal(t, "low", imagePart.ImageURL.Detail)
+ })
+}
+
+func TestToOpenAIPrompt_FileParts(t *testing.T) {
+ t.Parallel()
+
+ t.Run("should throw for unsupported mime types", func(t *testing.T) {
+ t.Parallel()
+
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleUser,
+ Content: []ai.MessagePart{
+ ai.FilePart{
+ MediaType: "application/something",
+ Data: []byte("test"),
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Len(t, warnings, 1)
+ require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
+ require.Len(t, messages, 1) // Message is still created but with empty content array
+ })
+
+ t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
+ t.Parallel()
+
+ audioData := []byte{0, 1, 2, 3}
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleUser,
+ Content: []ai.MessagePart{
+ ai.FilePart{
+ MediaType: "audio/wav",
+ Data: audioData,
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ userMsg := messages[0].OfUser
+ require.NotNil(t, userMsg)
+
+ content := userMsg.Content.OfArrayOfContentParts
+ require.Len(t, content, 1)
+
+ audioPart := content[0].OfInputAudio
+ require.NotNil(t, audioPart)
+ require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
+ require.Equal(t, "wav", audioPart.InputAudio.Format)
+ })
+
+ t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
+ t.Parallel()
+
+ audioData := []byte{0, 1, 2, 3}
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleUser,
+ Content: []ai.MessagePart{
+ ai.FilePart{
+ MediaType: "audio/mpeg",
+ Data: audioData,
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ userMsg := messages[0].OfUser
+ content := userMsg.Content.OfArrayOfContentParts
+ audioPart := content[0].OfInputAudio
+ require.NotNil(t, audioPart)
+ require.Equal(t, "mp3", audioPart.InputAudio.Format)
+ })
+
+ t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
+ t.Parallel()
+
+ audioData := []byte{0, 1, 2, 3}
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleUser,
+ Content: []ai.MessagePart{
+ ai.FilePart{
+ MediaType: "audio/mp3",
+ Data: audioData,
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ userMsg := messages[0].OfUser
+ content := userMsg.Content.OfArrayOfContentParts
+ audioPart := content[0].OfInputAudio
+ require.NotNil(t, audioPart)
+ require.Equal(t, "mp3", audioPart.InputAudio.Format)
+ })
+
+ t.Run("should convert messages with PDF file parts", func(t *testing.T) {
+ t.Parallel()
+
+ pdfData := []byte{1, 2, 3, 4, 5}
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleUser,
+ Content: []ai.MessagePart{
+ ai.FilePart{
+ MediaType: "application/pdf",
+ Data: pdfData,
+ Filename: "document.pdf",
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ userMsg := messages[0].OfUser
+ content := userMsg.Content.OfArrayOfContentParts
+ require.Len(t, content, 1)
+
+ filePart := content[0].OfFile
+ require.NotNil(t, filePart)
+ require.Equal(t, "document.pdf", filePart.File.Filename.Value)
+
+ expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
+ require.Equal(t, expectedData, filePart.File.FileData.Value)
+ })
+
+ t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
+ t.Parallel()
+
+ pdfData := []byte{1, 2, 3, 4, 5}
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleUser,
+ Content: []ai.MessagePart{
+ ai.FilePart{
+ MediaType: "application/pdf",
+ Data: pdfData,
+ Filename: "document.pdf",
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ userMsg := messages[0].OfUser
+ content := userMsg.Content.OfArrayOfContentParts
+ filePart := content[0].OfFile
+ require.NotNil(t, filePart)
+
+ expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
+ require.Equal(t, expectedData, filePart.File.FileData.Value)
+ })
+
+ t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
+ t.Parallel()
+
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleUser,
+ Content: []ai.MessagePart{
+ ai.FilePart{
+ MediaType: "application/pdf",
+ Data: []byte("file-pdf-12345"),
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ userMsg := messages[0].OfUser
+ content := userMsg.Content.OfArrayOfContentParts
+ filePart := content[0].OfFile
+ require.NotNil(t, filePart)
+ require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
+ require.True(t, param.IsOmitted(filePart.File.FileData))
+ require.True(t, param.IsOmitted(filePart.File.Filename))
+ })
+
+ t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
+ t.Parallel()
+
+ pdfData := []byte{1, 2, 3, 4, 5}
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleUser,
+ Content: []ai.MessagePart{
+ ai.FilePart{
+ MediaType: "application/pdf",
+ Data: pdfData,
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ userMsg := messages[0].OfUser
+ content := userMsg.Content.OfArrayOfContentParts
+ filePart := content[0].OfFile
+ require.NotNil(t, filePart)
+ require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
+ })
+}
+
+func TestToOpenAIPrompt_ToolCalls(t *testing.T) {
+ t.Parallel()
+
+ t.Run("should stringify arguments to tool calls", func(t *testing.T) {
+ t.Parallel()
+
+ inputArgs := map[string]any{"foo": "bar123"}
+ inputJSON, _ := json.Marshal(inputArgs)
+
+ outputResult := map[string]any{"oof": "321rab"}
+ outputJSON, _ := json.Marshal(outputResult)
+
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleAssistant,
+ Content: []ai.MessagePart{
+ ai.ToolCallPart{
+ ToolCallID: "quux",
+ ToolName: "thwomp",
+ Input: string(inputJSON),
+ },
+ },
+ },
+ {
+ Role: ai.MessageRoleTool,
+ Content: []ai.MessagePart{
+ ai.ToolResultPart{
+ ToolCallID: "quux",
+ Output: ai.ToolResultOutputContentText{
+ Text: string(outputJSON),
+ },
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 2)
+
+ // Check assistant message with tool call
+ assistantMsg := messages[0].OfAssistant
+ require.NotNil(t, assistantMsg)
+ require.Equal(t, "", assistantMsg.Content.OfString.Value)
+ require.Len(t, assistantMsg.ToolCalls, 1)
+
+ toolCall := assistantMsg.ToolCalls[0].OfFunction
+ require.NotNil(t, toolCall)
+ require.Equal(t, "quux", toolCall.ID)
+ require.Equal(t, "thwomp", toolCall.Function.Name)
+ require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
+
+ // Check tool message
+ toolMsg := messages[1].OfTool
+ require.NotNil(t, toolMsg)
+ require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
+ require.Equal(t, "quux", toolMsg.ToolCallID)
+ })
+
+ t.Run("should handle different tool output types", func(t *testing.T) {
+ t.Parallel()
+
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleTool,
+ Content: []ai.MessagePart{
+ ai.ToolResultPart{
+ ToolCallID: "text-tool",
+ Output: ai.ToolResultOutputContentText{
+ Text: "Hello world",
+ },
+ },
+ ai.ToolResultPart{
+ ToolCallID: "error-tool",
+ Output: ai.ToolResultOutputContentError{
+ Error: "Something went wrong",
+ },
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 2)
+
+ // Check first tool message (text)
+ textToolMsg := messages[0].OfTool
+ require.NotNil(t, textToolMsg)
+ require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
+ require.Equal(t, "text-tool", textToolMsg.ToolCallID)
+
+ // Check second tool message (error)
+ errorToolMsg := messages[1].OfTool
+ require.NotNil(t, errorToolMsg)
+ require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
+ require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
+ })
+}
+
+func TestToOpenAIPrompt_AssistantMessages(t *testing.T) {
+ t.Parallel()
+
+ t.Run("should handle simple text assistant messages", func(t *testing.T) {
+ t.Parallel()
+
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleAssistant,
+ Content: []ai.MessagePart{
+ ai.TextPart{Text: "Hello, how can I help you?"},
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ assistantMsg := messages[0].OfAssistant
+ require.NotNil(t, assistantMsg)
+ require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
+ })
+
+ t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
+ t.Parallel()
+
+ inputArgs := map[string]any{"query": "test"}
+ inputJSON, _ := json.Marshal(inputArgs)
+
+ prompt := ai.Prompt{
+ {
+ Role: ai.MessageRoleAssistant,
+ Content: []ai.MessagePart{
+ ai.TextPart{Text: "Let me search for that."},
+ ai.ToolCallPart{
+ ToolCallID: "call-123",
+ ToolName: "search",
+ Input: string(inputJSON),
+ },
+ },
+ },
+ }
+
+ messages, warnings := toOpenAIPrompt(prompt)
+
+ require.Empty(t, warnings)
+ require.Len(t, messages, 1)
+
+ assistantMsg := messages[0].OfAssistant
+ require.NotNil(t, assistantMsg)
+ require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
+ require.Len(t, assistantMsg.ToolCalls, 1)
+
+ toolCall := assistantMsg.ToolCalls[0].OfFunction
+ require.Equal(t, "call-123", toolCall.ID)
+ require.Equal(t, "search", toolCall.Function.Name)
+ require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
+ })
+}
+
+var testPrompt = ai.Prompt{
+ {
+ Role: ai.MessageRoleUser,
+ Content: []ai.MessagePart{
+ ai.TextPart{Text: "Hello"},
+ },
+ },
+}
+
+var testLogprobs = map[string]any{
+ "content": []map[string]any{
+ {
+ "token": "Hello",
+ "logprob": -0.0009994634,
+ "top_logprobs": []map[string]any{
+ {
+ "token": "Hello",
+ "logprob": -0.0009994634,
+ },
+ },
+ },
+ {
+ "token": "!",
+ "logprob": -0.13410144,
+ "top_logprobs": []map[string]any{
+ {
+ "token": "!",
+ "logprob": -0.13410144,
+ },
+ },
+ },
+ {
+ "token": " How",
+ "logprob": -0.0009250381,
+ "top_logprobs": []map[string]any{
+ {
+ "token": " How",
+ "logprob": -0.0009250381,
+ },
+ },
+ },
+ {
+ "token": " can",
+ "logprob": -0.047709424,
+ "top_logprobs": []map[string]any{
+ {
+ "token": " can",
+ "logprob": -0.047709424,
+ },
+ },
+ },
+ {
+ "token": " I",
+ "logprob": -0.000009014684,
+ "top_logprobs": []map[string]any{
+ {
+ "token": " I",
+ "logprob": -0.000009014684,
+ },
+ },
+ },
+ {
+ "token": " assist",
+ "logprob": -0.009125131,
+ "top_logprobs": []map[string]any{
+ {
+ "token": " assist",
+ "logprob": -0.009125131,
+ },
+ },
+ },
+ {
+ "token": " you",
+ "logprob": -0.0000066306106,
+ "top_logprobs": []map[string]any{
+ {
+ "token": " you",
+ "logprob": -0.0000066306106,
+ },
+ },
+ },
+ {
+ "token": " today",
+ "logprob": -0.00011093382,
+ "top_logprobs": []map[string]any{
+ {
+ "token": " today",
+ "logprob": -0.00011093382,
+ },
+ },
+ },
+ {
+ "token": "?",
+ "logprob": -0.00004596782,
+ "top_logprobs": []map[string]any{
+ {
+ "token": "?",
+ "logprob": -0.00004596782,
+ },
+ },
+ },
+ },
+}
+
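+// mockServer is an httptest-backed stub of the chat completions endpoint that
+// records every request and replies with a configurable JSON response.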
+type mockServer struct {
+ server *httptest.Server
+ response map[string]any
+ calls []mockCall
+}
+
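+// mockCall captures the method, path, headers, and decoded body of one recorded request.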
+type mockCall struct {
+ method string
+ path string
+ headers map[string]string
+ body map[string]any
+}
+
+func newMockServer() *mockServer {
+ ms := &mockServer{
+ calls: make([]mockCall, 0),
+ }
+
+ ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Record the call
+ call := mockCall{
+ method: r.Method,
+ path: r.URL.Path,
+ headers: make(map[string]string),
+ }
+
+ for k, v := range r.Header {
+ if len(v) > 0 {
+ call.headers[k] = v[0]
+ }
+ }
+
+ // Parse request body
+ if r.Body != nil {
+ var body map[string]any
+ json.NewDecoder(r.Body).Decode(&body)
+ call.body = body
+ }
+
+ ms.calls = append(ms.calls, call)
+
+ // Return mock response
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(ms.response)
+ }))
+
+ return ms
+}
+
+func (ms *mockServer) close() {
+ ms.server.Close()
+}
+
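+// prepareJSONResponse installs a default chat completion response, overriding
+// individual fields (content, usage, finish_reason, ...) with the given options.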
+func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
+ // Default values
+ response := map[string]any{
+ "id": "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
+ "object": "chat.completion",
+ "created": 1711115037,
+ "model": "gpt-3.5-turbo-0125",
+ "choices": []map[string]any{
+ {
+ "index": 0,
+ "message": map[string]any{
+ "role": "assistant",
+ "content": "",
+ },
+ "finish_reason": "stop",
+ },
+ },
+ "usage": map[string]any{
+ "prompt_tokens": 4,
+ "total_tokens": 34,
+ "completion_tokens": 30,
+ },
+ "system_fingerprint": "fp_3bc1b5746c",
+ }
+
+ // Override with provided options
+ for k, v := range opts {
+ switch k {
+ case "content":
+ response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
+ case "tool_calls":
+ response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
+ case "function_call":
+ response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
+ case "annotations":
+ response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
+ case "usage":
+ response["usage"] = v
+ case "finish_reason":
+ response["choices"].([]map[string]any)[0]["finish_reason"] = v
+ case "id":
+ response["id"] = v
+ case "created":
+ response["created"] = v
+ case "model":
+ response["model"] = v
+ case "logprobs":
+ if v != nil {
+ response["choices"].([]map[string]any)[0]["logprobs"] = v
+ }
+ }
+ }
+
+ ms.response = response
+}
+
+func TestDoGenerate(t *testing.T) {
+ t.Parallel()
+
+ t.Run("should extract text response", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "Hello, World!",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+ require.Len(t, result.Content, 1)
+
+ textContent, ok := result.Content[0].(ai.TextContent)
+ require.True(t, ok)
+ require.Equal(t, "Hello, World!", textContent.Text)
+ })
+
+ t.Run("should extract usage", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "usage": map[string]any{
+ "prompt_tokens": 20,
+ "total_tokens": 25,
+ "completion_tokens": 5,
+ },
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, int64(20), result.Usage.InputTokens)
+ require.Equal(t, int64(5), result.Usage.OutputTokens)
+ require.Equal(t, int64(25), result.Usage.TotalTokens)
+ })
+
+ t.Run("should send request body", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{})
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "POST", call.method)
+ require.Equal(t, "/chat/completions", call.path)
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+ messages, ok := call.body["messages"].([]any)
+ require.True(t, ok)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should support partial usage", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "usage": map[string]any{
+ "prompt_tokens": 20,
+ "total_tokens": 20,
+ },
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, int64(20), result.Usage.InputTokens)
+ require.Equal(t, int64(0), result.Usage.OutputTokens)
+ require.Equal(t, int64(20), result.Usage.TotalTokens)
+ })
+
+ t.Run("should extract logprobs", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "logprobs": testLogprobs,
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "logProbs": true,
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result.ProviderMetadata)
+
+ openaiMeta, ok := result.ProviderMetadata["openai"]
+ require.True(t, ok)
+
+ logprobs, ok := openaiMeta["logprobs"]
+ require.True(t, ok)
+ require.NotNil(t, logprobs)
+ })
+
+ t.Run("should extract finish reason", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "finish_reason": "stop",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, ai.FinishReasonStop, result.FinishReason)
+ })
+
+ t.Run("should support unknown finish reason", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "finish_reason": "eos",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, ai.FinishReasonUnknown, result.FinishReason)
+ })
+
+ t.Run("should pass the model and the messages", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should pass settings", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{})
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "logitBias": map[string]int64{
+ "50256": -100,
+ },
+ "parallelToolCalls": false,
+ "user": "test-user-id",
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ logitBias := call.body["logit_bias"].(map[string]any)
+ require.Equal(t, float64(-100), logitBias["50256"])
+ require.Equal(t, false, call.body["parallel_tool_calls"])
+ require.Equal(t, "test-user-id", call.body["user"])
+ })
+
+ t.Run("should pass reasoningEffort setting", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("o1-mini")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "reasoningEffort": "low",
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "o1-mini", call.body["model"])
+ require.Equal(t, "low", call.body["reasoning_effort"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should pass textVerbosity setting", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-4o")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "textVerbosity": "low",
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-4o", call.body["model"])
+ require.Equal(t, "low", call.body["verbosity"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should pass tools and toolChoice", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ Tools: []ai.Tool{
+ ai.FunctionTool{
+ Name: "test-tool",
+ InputSchema: map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "value": map[string]any{
+ "type": "string",
+ },
+ },
+ "required": []string{"value"},
+ "additionalProperties": false,
+ "$schema": "http://json-schema.org/draft-07/schema#",
+ },
+ },
+ },
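+ // &[]ai.ToolChoice{v}[0] is a compact way to take the address of a literal value.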
+ ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ tools := call.body["tools"].([]any)
+ require.Len(t, tools, 1)
+
+ tool := tools[0].(map[string]any)
+ require.Equal(t, "function", tool["type"])
+
+ function := tool["function"].(map[string]any)
+ require.Equal(t, "test-tool", function["name"])
+ require.Equal(t, false, function["strict"])
+
+ toolChoice := call.body["tool_choice"].(map[string]any)
+ require.Equal(t, "function", toolChoice["type"])
+
+ toolChoiceFunction := toolChoice["function"].(map[string]any)
+ require.Equal(t, "test-tool", toolChoiceFunction["name"])
+ })
+
+ t.Run("should parse tool results", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "tool_calls": []map[string]any{
+ {
+ "id": "call_O17Uplv4lJvD6DVdIvFFeRMw",
+ "type": "function",
+ "function": map[string]any{
+ "name": "test-tool",
+ "arguments": `{"value":"Spark"}`,
+ },
+ },
+ },
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ Tools: []ai.Tool{
+ ai.FunctionTool{
+ Name: "test-tool",
+ InputSchema: map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "value": map[string]any{
+ "type": "string",
+ },
+ },
+ "required": []string{"value"},
+ "additionalProperties": false,
+ "$schema": "http://json-schema.org/draft-07/schema#",
+ },
+ },
+ },
+ ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
+ })
+
+ require.NoError(t, err)
+ require.Len(t, result.Content, 1)
+
+ toolCall, ok := result.Content[0].(ai.ToolCallContent)
+ require.True(t, ok)
+ require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
+ require.Equal(t, "test-tool", toolCall.ToolName)
+ require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
+ })
+
+ t.Run("should parse annotations/citations", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "Based on the search results [doc1], I found information.",
+ "annotations": []map[string]any{
+ {
+ "type": "url_citation",
+ "url_citation": map[string]any{
+ "start_index": 24,
+ "end_index": 29,
+ "url": "https://example.com/doc1.pdf",
+ "title": "Document 1",
+ },
+ },
+ },
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+ require.Len(t, result.Content, 2)
+
+ textContent, ok := result.Content[0].(ai.TextContent)
+ require.True(t, ok)
+ require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)
+
+ sourceContent, ok := result.Content[1].(ai.SourceContent)
+ require.True(t, ok)
+ require.Equal(t, ai.SourceTypeURL, sourceContent.SourceType)
+ require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
+ require.Equal(t, "Document 1", sourceContent.Title)
+ require.NotEmpty(t, sourceContent.ID)
+ })
+
+ t.Run("should return cached_tokens in prompt_details_tokens", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "usage": map[string]any{
+ "prompt_tokens": 15,
+ "completion_tokens": 20,
+ "total_tokens": 35,
+ "prompt_tokens_details": map[string]any{
+ "cached_tokens": 1152,
+ },
+ },
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-4o-mini")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
+ require.Equal(t, int64(15), result.Usage.InputTokens)
+ require.Equal(t, int64(20), result.Usage.OutputTokens)
+ require.Equal(t, int64(35), result.Usage.TotalTokens)
+ })
+
+ t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "usage": map[string]any{
+ "prompt_tokens": 15,
+ "completion_tokens": 20,
+ "total_tokens": 35,
+ "completion_tokens_details": map[string]any{
+ "accepted_prediction_tokens": 123,
+ "rejected_prediction_tokens": 456,
+ },
+ },
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-4o-mini")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result.ProviderMetadata)
+
+ openaiMeta, ok := result.ProviderMetadata["openai"]
+ require.True(t, ok)
+ require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
+ require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
+ })
+
+ t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{})
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("o1-preview")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ Temperature: &[]float64{0.5}[0],
+ TopP: &[]float64{0.7}[0],
+ FrequencyPenalty: &[]float64{0.2}[0],
+ PresencePenalty: &[]float64{0.3}[0],
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "o1-preview", call.body["model"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+
+ // These should not be present
+ require.Nil(t, call.body["temperature"])
+ require.Nil(t, call.body["top_p"])
+ require.Nil(t, call.body["frequency_penalty"])
+ require.Nil(t, call.body["presence_penalty"])
+
+ // Should have warnings
+ require.Len(t, result.Warnings, 4)
+ require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
+ require.Equal(t, "temperature", result.Warnings[0].Setting)
+ require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
+ })
+
+ t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{})
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("o1-preview")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ MaxOutputTokens: &[]int64{1000}[0],
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "o1-preview", call.body["model"])
+ require.Equal(t, float64(1000), call.body["max_completion_tokens"])
+ require.Nil(t, call.body["max_tokens"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should return reasoning tokens", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "usage": map[string]any{
+ "prompt_tokens": 15,
+ "completion_tokens": 20,
+ "total_tokens": 35,
+ "completion_tokens_details": map[string]any{
+ "reasoning_tokens": 10,
+ },
+ },
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("o1-preview")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, int64(15), result.Usage.InputTokens)
+ require.Equal(t, int64(20), result.Usage.OutputTokens)
+ require.Equal(t, int64(35), result.Usage.TotalTokens)
+ require.Equal(t, int64(10), result.Usage.ReasoningTokens)
+ })
+
+ t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "model": "o1-preview",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("o1-preview")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "maxCompletionTokens": 255,
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "o1-preview", call.body["model"])
+ require.Equal(t, float64(255), call.body["max_completion_tokens"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should send prediction extension setting", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "prediction": map[string]any{
+ "type": "content",
+ "content": "Hello, World!",
+ },
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+ prediction := call.body["prediction"].(map[string]any)
+ require.Equal(t, "content", prediction["type"])
+ require.Equal(t, "Hello, World!", prediction["content"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should send store extension setting", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "store": true,
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+ require.Equal(t, true, call.body["store"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should send metadata extension values", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "metadata": map[string]any{
+ "custom": "value",
+ },
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+
+ metadata := call.body["metadata"].(map[string]any)
+ require.Equal(t, "value", metadata["custom"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should send promptCacheKey extension value", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "promptCacheKey": "test-cache-key-123",
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+ require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should send safetyIdentifier extension value", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "safetyIdentifier": "test-safety-identifier-123",
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+ require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{})
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-4o-search-preview")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ Temperature: &[]float64{0.7}[0],
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-4o-search-preview", call.body["model"])
+ require.Nil(t, call.body["temperature"])
+
+ require.Len(t, result.Warnings, 1)
+ require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
+ require.Equal(t, "temperature", result.Warnings[0].Setting)
+ require.Contains(t, result.Warnings[0].Details, "search preview models")
+ })
+
+ t.Run("should send serviceTier flex processing setting", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{
+ "content": "",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("o3-mini")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "serviceTier": "flex",
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "o3-mini", call.body["model"])
+ require.Equal(t, "flex", call.body["service_tier"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{})
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-4o-mini")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "serviceTier": "flex",
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Nil(t, call.body["service_tier"])
+
+ require.Len(t, result.Warnings, 1)
+ require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
+ require.Equal(t, "serviceTier", result.Warnings[0].Setting)
+ require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
+ })
+
+ t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{})
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-4o-mini")
+
+ _, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "serviceTier": "priority",
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-4o-mini", call.body["model"])
+ require.Equal(t, "priority", call.body["service_tier"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
+ t.Parallel()
+
+ server := newMockServer()
+ defer server.close()
+
+ server.prepareJSONResponse(map[string]any{})
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ result, err := model.Generate(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "serviceTier": "priority",
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Nil(t, call.body["service_tier"])
+
+ require.Len(t, result.Warnings, 1)
+ require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
+ require.Equal(t, "serviceTier", result.Warnings[0].Setting)
+ require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
+ })
+}
+
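+// streamingMockServer replays canned SSE chunks for the streaming tests. A
+// chunk prefixed with "HEADER:name:value" is applied as a response header
+// instead of being written to the body.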
+type streamingMockServer struct {
+ server *httptest.Server
+ chunks []string
+ calls []mockCall
+}
+
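+// newStreamingMockServer starts an httptest.Server that records each request
+// (method, path, headers, JSON body) and streams the prepared chunks back as
+// a text/event-stream response.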
+func newStreamingMockServer() *streamingMockServer {
+ sms := &streamingMockServer{
+ calls: make([]mockCall, 0),
+ }
+
+ sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Record the call
+ call := mockCall{
+ method: r.Method,
+ path: r.URL.Path,
+ headers: make(map[string]string),
+ }
+
+ for k, v := range r.Header {
+ if len(v) > 0 {
+ call.headers[k] = v[0]
+ }
+ }
+
+ // Parse request body
+ if r.Body != nil {
+ var body map[string]any
+ // Best-effort decode: a malformed body simply leaves call.body nil.
+ _ = json.NewDecoder(r.Body).Decode(&body)
+ call.body = body
+ }
+
+ sms.calls = append(sms.calls, call)
+
+ // Set streaming headers
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Cache-Control", "no-cache")
+ w.Header().Set("Connection", "keep-alive")
+
+ // Add custom headers if any
+ for _, chunk := range sms.chunks {
+ if strings.HasPrefix(chunk, "HEADER:") {
+ parts := strings.SplitN(chunk[7:], ":", 2)
+ if len(parts) == 2 {
+ w.Header().Set(parts[0], parts[1])
+ }
+ continue
+ }
+ }
+
+ w.WriteHeader(http.StatusOK)
+
+ // Write chunks
+ for _, chunk := range sms.chunks {
+ if strings.HasPrefix(chunk, "HEADER:") {
+ continue
+ }
+ w.Write([]byte(chunk))
+ if f, ok := w.(http.Flusher); ok {
+ f.Flush()
+ }
+ }
+ }))
+
+ return sms
+}
+
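+// close shuts down the underlying test server.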
+func (sms *streamingMockServer) close() {
+ sms.server.Close()
+}
+
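+// prepareStreamResponse builds a complete SSE exchange (role chunk, content
+// deltas, optional annotations, finish chunk, usage chunk, [DONE]) from the
+// given options. Recognized keys: "content" ([]string), "usage", "logprobs",
+// "finish_reason", "model", "headers", and "annotations". For example:
+//
+//	server.prepareStreamResponse(map[string]any{
+//		"content":       []string{"Hello", ", ", "World!"},
+//		"finish_reason": "stop",
+//	})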
+func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
+ content := []string{}
+ if c, ok := opts["content"].([]string); ok {
+ content = c
+ }
+
+ usage := map[string]any{
+ "prompt_tokens": 17,
+ "total_tokens": 244,
+ "completion_tokens": 227,
+ }
+ if u, ok := opts["usage"].(map[string]any); ok {
+ usage = u
+ }
+
+ logprobs := map[string]any{}
+ if l, ok := opts["logprobs"].(map[string]any); ok {
+ logprobs = l
+ }
+
+ finishReason := "stop"
+ if fr, ok := opts["finish_reason"].(string); ok {
+ finishReason = fr
+ }
+
+ model := "gpt-3.5-turbo-0613"
+ if m, ok := opts["model"].(string); ok {
+ model = m
+ }
+
+ headers := map[string]string{}
+ if h, ok := opts["headers"].(map[string]string); ok {
+ headers = h
+ }
+
+ chunks := []string{}
+
+ // Add custom headers
+ for k, v := range headers {
+ chunks = append(chunks, "HEADER:"+k+":"+v)
+ }
+
+ // Initial chunk with role
+ initialChunk := map[string]any{
+ "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
+ "object": "chat.completion.chunk",
+ "created": 1702657020,
+ "model": model,
+ "system_fingerprint": nil,
+ "choices": []map[string]any{
+ {
+ "index": 0,
+ "delta": map[string]any{
+ "role": "assistant",
+ "content": "",
+ },
+ "finish_reason": nil,
+ },
+ },
+ }
+ initialData, _ := json.Marshal(initialChunk)
+ chunks = append(chunks, "data: "+string(initialData)+"\n\n")
+
+ // Content chunks
+ for i, text := range content {
+ contentChunk := map[string]any{
+ "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
+ "object": "chat.completion.chunk",
+ "created": 1702657020,
+ "model": model,
+ "system_fingerprint": nil,
+ "choices": []map[string]any{
+ {
+ "index": 1,
+ "delta": map[string]any{
+ "content": text,
+ },
+ "finish_reason": nil,
+ },
+ },
+ }
+ contentData, _ := json.Marshal(contentChunk)
+ chunks = append(chunks, "data: "+string(contentData)+"\n\n")
+
+ // Add annotations if this is the last content chunk and we have annotations
+ if i == len(content)-1 {
+ if annotations, ok := opts["annotations"].([]map[string]any); ok {
+ annotationChunk := map[string]any{
+ "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
+ "object": "chat.completion.chunk",
+ "created": 1702657020,
+ "model": model,
+ "system_fingerprint": nil,
+ "choices": []map[string]any{
+ {
+ "index": 1,
+ "delta": map[string]any{
+ "annotations": annotations,
+ },
+ "finish_reason": nil,
+ },
+ },
+ }
+ annotationData, _ := json.Marshal(annotationChunk)
+ chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
+ }
+ }
+ }
+
+ // Finish chunk
+ finishChunk := map[string]any{
+ "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
+ "object": "chat.completion.chunk",
+ "created": 1702657020,
+ "model": model,
+ "system_fingerprint": nil,
+ "choices": []map[string]any{
+ {
+ "index": 0,
+ "delta": map[string]any{},
+ "finish_reason": finishReason,
+ },
+ },
+ }
+
+ if len(logprobs) > 0 {
+ finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
+ }
+
+ finishData, _ := json.Marshal(finishChunk)
+ chunks = append(chunks, "data: "+string(finishData)+"\n\n")
+
+ // Usage chunk
+ usageChunk := map[string]any{
+ "id": "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
+ "object": "chat.completion.chunk",
+ "created": 1702657020,
+ "model": model,
+ "system_fingerprint": "fp_3bc1b5746c",
+ "choices": []map[string]any{},
+ "usage": usage,
+ }
+ usageData, _ := json.Marshal(usageChunk)
+ chunks = append(chunks, "data: "+string(usageData)+"\n\n")
+
+ // Done
+ chunks = append(chunks, "data: [DONE]\n\n")
+
+ sms.chunks = chunks
+}
+
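+// prepareToolStreamResponse replays a fixed chat.completion.chunk sequence in
+// which the arguments of a single "test-tool" call arrive as incremental
+// deltas that assemble into {"value":"Sparkle Day"}.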
+func (sms *streamingMockServer) prepareToolStreamResponse() {
+ chunks := []string{
+ `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+ `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+ `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+ `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+ `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+ `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+ `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+ `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
+ `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
+ `data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
+ "data: [DONE]\n\n",
+ }
+ sms.chunks = chunks
+}
+
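+// prepareErrorStreamResponse replays a stream whose first event is an OpenAI
+// error object instead of a completion chunk.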
+func (sms *streamingMockServer) prepareErrorStreamResponse() {
+ chunks := []string{
+ `data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
+ "data: [DONE]\n\n",
+ }
+ sms.chunks = chunks
+}
+
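+// collectStreamParts drains a stream into a slice, stopping after the first
+// error or finish part so tests can assert on the collected sequence.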
+func collectStreamParts(stream ai.StreamResponse) ([]ai.StreamPart, error) {
+ var parts []ai.StreamPart
+ for part := range stream {
+ parts = append(parts, part)
+ if part.Type == ai.StreamPartTypeError || part.Type == ai.StreamPartTypeFinish {
+ break
+ }
+ }
+ return parts, nil
+}
+
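+// TestDoStream exercises the streaming code path against the mock SSE server.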
+func TestDoStream(t *testing.T) {
+ t.Parallel()
+
+ t.Run("should stream text deltas", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareStreamResponse(map[string]any{
+ "content": []string{"Hello", ", ", "World!"},
+ "finish_reason": "stop",
+ "usage": map[string]any{
+ "prompt_tokens": 17,
+ "total_tokens": 244,
+ "completion_tokens": 227,
+ },
+ "logprobs": testLogprobs,
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ stream, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+
+ parts, err := collectStreamParts(stream)
+ require.NoError(t, err)
+
+ // Verify stream structure
+ require.GreaterOrEqual(t, len(parts), 4) // text-start, deltas, text-end, finish
+
+ // Find text parts
+ textStart, textEnd, finish := -1, -1, -1
+ var deltas []string
+
+ for i, part := range parts {
+ switch part.Type {
+ case ai.StreamPartTypeTextStart:
+ textStart = i
+ case ai.StreamPartTypeTextDelta:
+ deltas = append(deltas, part.Delta)
+ case ai.StreamPartTypeTextEnd:
+ textEnd = i
+ case ai.StreamPartTypeFinish:
+ finish = i
+ }
+ }
+
+ require.NotEqual(t, -1, textStart)
+ require.NotEqual(t, -1, textEnd)
+ require.NotEqual(t, -1, finish)
+ require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)
+
+ // Check finish part
+ finishPart := parts[finish]
+ require.Equal(t, ai.FinishReasonStop, finishPart.FinishReason)
+ require.Equal(t, int64(17), finishPart.Usage.InputTokens)
+ require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
+ require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
+ })
+
+ t.Run("should stream tool deltas", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareToolStreamResponse()
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ stream, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ Tools: []ai.Tool{
+ ai.FunctionTool{
+ Name: "test-tool",
+ InputSchema: map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "value": map[string]any{
+ "type": "string",
+ },
+ },
+ "required": []string{"value"},
+ "additionalProperties": false,
+ "$schema": "http://json-schema.org/draft-07/schema#",
+ },
+ },
+ },
+ })
+
+ require.NoError(t, err)
+
+ parts, err := collectStreamParts(stream)
+ require.NoError(t, err)
+
+ // Find tool-related parts
+ toolInputStart, toolInputEnd, toolCall := -1, -1, -1
+ var toolDeltas []string
+
+ for i, part := range parts {
+ switch part.Type {
+ case ai.StreamPartTypeToolInputStart:
+ toolInputStart = i
+ require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
+ require.Equal(t, "test-tool", part.ToolCallName)
+ case ai.StreamPartTypeToolInputDelta:
+ toolDeltas = append(toolDeltas, part.Delta)
+ case ai.StreamPartTypeToolInputEnd:
+ toolInputEnd = i
+ case ai.StreamPartTypeToolCall:
+ toolCall = i
+ require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
+ require.Equal(t, "test-tool", part.ToolCallName)
+ require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
+ }
+ }
+
+ require.NotEqual(t, -1, toolInputStart)
+ require.NotEqual(t, -1, toolInputEnd)
+ require.NotEqual(t, -1, toolCall)
+
+ // Verify tool deltas combine to form the complete input
+ fullInput := ""
+ for _, delta := range toolDeltas {
+ fullInput += delta
+ }
+ require.Equal(t, `{"value":"Sparkle Day"}`, fullInput)
+ })
+
+ t.Run("should stream annotations/citations", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareStreamResponse(map[string]any{
+ "content": []string{"Based on search results"},
+ "annotations": []map[string]any{
+ {
+ "type": "url_citation",
+ "url_citation": map[string]any{
+ "start_index": 24,
+ "end_index": 29,
+ "url": "https://example.com/doc1.pdf",
+ "title": "Document 1",
+ },
+ },
+ },
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ stream, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+
+ parts, err := collectStreamParts(stream)
+ require.NoError(t, err)
+
+ // Find source part
+ var sourcePart *ai.StreamPart
+ for _, part := range parts {
+ if part.Type == ai.StreamPartTypeSource {
+ sourcePart = &part
+ break
+ }
+ }
+
+ require.NotNil(t, sourcePart)
+ require.Equal(t, ai.SourceTypeURL, sourcePart.SourceType)
+ require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
+ require.Equal(t, "Document 1", sourcePart.Title)
+ require.NotEmpty(t, sourcePart.ID)
+ })
+
+ t.Run("should handle error stream parts", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareErrorStreamResponse()
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ stream, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+
+ parts, err := collectStreamParts(stream)
+ require.NoError(t, err)
+
+ // Should contain at least the error part
+ require.NotEmpty(t, parts)
+
+ // Find error part
+ var errorPart *ai.StreamPart
+ for _, part := range parts {
+ if part.Type == ai.StreamPartTypeError {
+ errorPart = &part
+ break
+ }
+ }
+
+ require.NotNil(t, errorPart)
+ require.NotNil(t, errorPart.Error)
+ })
+
+ t.Run("should send request body", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareStreamResponse(map[string]any{
+ "content": []string{},
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ _, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "POST", call.method)
+ require.Equal(t, "/chat/completions", call.path)
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+ require.Equal(t, true, call.body["stream"])
+
+ streamOptions := call.body["stream_options"].(map[string]any)
+ require.Equal(t, true, streamOptions["include_usage"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareStreamResponse(map[string]any{
+ "content": []string{},
+ "usage": map[string]any{
+ "prompt_tokens": 15,
+ "completion_tokens": 20,
+ "total_tokens": 35,
+ "prompt_tokens_details": map[string]any{
+ "cached_tokens": 1152,
+ },
+ },
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ stream, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+
+ parts, err := collectStreamParts(stream)
+ require.NoError(t, err)
+
+ // Find finish part
+ var finishPart *ai.StreamPart
+ for _, part := range parts {
+ if part.Type == ai.StreamPartTypeFinish {
+ finishPart = &part
+ break
+ }
+ }
+
+ require.NotNil(t, finishPart)
+ require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
+ require.Equal(t, int64(15), finishPart.Usage.InputTokens)
+ require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
+ require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
+ })
+
+ t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareStreamResponse(map[string]any{
+ "content": []string{},
+ "usage": map[string]any{
+ "prompt_tokens": 15,
+ "completion_tokens": 20,
+ "total_tokens": 35,
+ "completion_tokens_details": map[string]any{
+ "accepted_prediction_tokens": 123,
+ "rejected_prediction_tokens": 456,
+ },
+ },
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ stream, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+
+ parts, err := collectStreamParts(stream)
+ require.NoError(t, err)
+
+ // Find finish part
+ var finishPart *ai.StreamPart
+ for _, part := range parts {
+ if part.Type == ai.StreamPartTypeFinish {
+ finishPart = &part
+ break
+ }
+ }
+
+ require.NotNil(t, finishPart)
+ require.NotNil(t, finishPart.ProviderMetadata)
+
+ openaiMeta, ok := finishPart.ProviderMetadata["openai"]
+ require.True(t, ok)
+ require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
+ require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
+ })
+
+ t.Run("should send store extension setting", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareStreamResponse(map[string]any{
+ "content": []string{},
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ _, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "store": true,
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+ require.Equal(t, true, call.body["stream"])
+ require.Equal(t, true, call.body["store"])
+
+ streamOptions := call.body["stream_options"].(map[string]any)
+ require.Equal(t, true, streamOptions["include_usage"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should send metadata extension values", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareStreamResponse(map[string]any{
+ "content": []string{},
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-3.5-turbo")
+
+ _, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "metadata": map[string]any{
+ "custom": "value",
+ },
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-3.5-turbo", call.body["model"])
+ require.Equal(t, true, call.body["stream"])
+
+ metadata := call.body["metadata"].(map[string]any)
+ require.Equal(t, "value", metadata["custom"])
+
+ streamOptions := call.body["stream_options"].(map[string]any)
+ require.Equal(t, true, streamOptions["include_usage"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareStreamResponse(map[string]any{
+ "content": []string{},
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("o3-mini")
+
+ _, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "serviceTier": "flex",
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "o3-mini", call.body["model"])
+ require.Equal(t, "flex", call.body["service_tier"])
+ require.Equal(t, true, call.body["stream"])
+
+ streamOptions := call.body["stream_options"].(map[string]any)
+ require.Equal(t, true, streamOptions["include_usage"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareStreamResponse(map[string]any{
+ "content": []string{},
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("gpt-4o-mini")
+
+ _, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ ProviderOptions: ai.ProviderOptions{
+ "openai": map[string]any{
+ "serviceTier": "priority",
+ },
+ },
+ })
+
+ require.NoError(t, err)
+ require.Len(t, server.calls, 1)
+
+ call := server.calls[0]
+ require.Equal(t, "gpt-4o-mini", call.body["model"])
+ require.Equal(t, "priority", call.body["service_tier"])
+ require.Equal(t, true, call.body["stream"])
+
+ streamOptions := call.body["stream_options"].(map[string]any)
+ require.Equal(t, true, streamOptions["include_usage"])
+
+ messages := call.body["messages"].([]any)
+ require.Len(t, messages, 1)
+
+ message := messages[0].(map[string]any)
+ require.Equal(t, "user", message["role"])
+ require.Equal(t, "Hello", message["content"])
+ })
+
+ t.Run("should stream text delta for reasoning models", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareStreamResponse(map[string]any{
+ "content": []string{"Hello, World!"},
+ "model": "o1-preview",
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("o1-preview")
+
+ stream, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+
+ parts, err := collectStreamParts(stream)
+ require.NoError(t, err)
+
+ // Find text parts
+ var textDeltas []string
+ for _, part := range parts {
+ if part.Type == ai.StreamPartTypeTextDelta {
+ textDeltas = append(textDeltas, part.Delta)
+ }
+ }
+
+ // Should contain the text content (without empty delta)
+ require.Equal(t, []string{"Hello, World!"}, textDeltas)
+ })
+
+ t.Run("should send reasoning tokens", func(t *testing.T) {
+ t.Parallel()
+
+ server := newStreamingMockServer()
+ defer server.close()
+
+ server.prepareStreamResponse(map[string]any{
+ "content": []string{"Hello, World!"},
+ "model": "o1-preview",
+ "usage": map[string]any{
+ "prompt_tokens": 15,
+ "completion_tokens": 20,
+ "total_tokens": 35,
+ "completion_tokens_details": map[string]any{
+ "reasoning_tokens": 10,
+ },
+ },
+ })
+
+ provider := NewOpenAIProvider(
+ WithOpenAIApiKey("test-api-key"),
+ WithOpenAIBaseURL(server.server.URL),
+ )
+ model := provider.LanguageModel("o1-preview")
+
+ stream, err := model.Stream(context.Background(), ai.Call{
+ Prompt: testPrompt,
+ })
+
+ require.NoError(t, err)
+
+ parts, err := collectStreamParts(stream)
+ require.NoError(t, err)
+
+ // Find finish part
+ var finishPart *ai.StreamPart
+ for _, part := range parts {
+ if part.Type == ai.StreamPartTypeFinish {
+ finishPart = &part
+ break
+ }
+ }
+
+ require.NotNil(t, finishPart)
+ require.Equal(t, int64(15), finishPart.Usage.InputTokens)
+ require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
+ require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
+ require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
+ })
+}