1package providers
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "maps"
11 "strings"
12
13 "github.com/charmbracelet/crush/internal/ai"
14 "github.com/google/uuid"
15 "github.com/openai/openai-go/v2"
16 "github.com/openai/openai-go/v2/option"
17 "github.com/openai/openai-go/v2/packages/param"
18 "github.com/openai/openai-go/v2/shared"
19)
20
// ReasoningEffort selects how much effort an OpenAI reasoning model spends
// on a request; it is forwarded as the `reasoning_effort` request parameter.
type ReasoningEffort string

// Supported reasoning-effort levels, mirroring the OpenAI API values.
const (
	ReasoningEffortMinimal ReasoningEffort = "minimal"
	ReasoningEffortLow     ReasoningEffort = "low"
	ReasoningEffortMedium  ReasoningEffort = "medium"
	ReasoningEffortHigh    ReasoningEffort = "high"
)
29
// OpenAIProviderOptions carries OpenAI-specific request settings, parsed from
// the "openai" key of ai.Call.ProviderOptions. Pointer fields distinguish
// "unset" from zero values; only set fields are forwarded to the API.
type OpenAIProviderOptions struct {
	LogitBias map[string]int64 `json:"logit_bias"`
	// NOTE(review): tag "log_probes" looks like a typo for "log_probs" —
	// confirm before renaming, since changing it breaks existing configs.
	LogProbs            *bool            `json:"log_probes"`
	TopLogProbs         *int64           `json:"top_log_probs"`
	ParallelToolCalls   *bool            `json:"parallel_tool_calls"`
	User                *string          `json:"user"`
	ReasoningEffort     *ReasoningEffort `json:"reasoning_effort"`
	MaxCompletionTokens *int64           `json:"max_completion_tokens"`
	TextVerbosity       *string          `json:"text_verbosity"`
	// Prediction is passed through for predicted outputs; only a string
	// "content" entry is currently honored (see prepareParams).
	Prediction map[string]any `json:"prediction"`
	Store      *bool          `json:"store"`
	// Metadata values must be strings; non-string values are dropped.
	Metadata          map[string]any `json:"metadata"`
	PromptCacheKey    *string        `json:"prompt_cache_key"`
	SafetyIdentifier  *string        `json:"safety_identifier"`
	ServiceTier       *string        `json:"service_tier"`
	StructuredOutputs *bool          `json:"structured_outputs"`
}
47
// openAIProvider creates OpenAI-backed language models. Construct it with
// NewOpenAIProvider.
type openAIProvider struct {
	options openAIProviderOptions
}
51
// openAIProviderOptions holds provider-level configuration collected by the
// WithOpenAI* functional options.
type openAIProviderOptions struct {
	baseURL      string            // API endpoint; defaults to the public OpenAI URL
	apiKey       string            // bearer token for authentication
	organization string            // sent as the OpenAI-Organization header when set
	project      string            // sent as the OpenAI-Project header when set
	name         string            // provider name used in model identifiers; defaults to "openai"
	headers      map[string]string // extra headers added to every request
	client       option.HTTPClient // optional custom HTTP client
}
61
// OpenAIOption configures the provider built by NewOpenAIProvider.
type OpenAIOption = func(*openAIProviderOptions)
63
64func NewOpenAIProvider(opts ...OpenAIOption) ai.Provider {
65 options := openAIProviderOptions{
66 headers: map[string]string{},
67 }
68 for _, o := range opts {
69 o(&options)
70 }
71
72 if options.baseURL == "" {
73 options.baseURL = "https://api.openai.com/v1"
74 }
75
76 if options.name == "" {
77 options.name = "openai"
78 }
79
80 if options.organization != "" {
81 options.headers["OpenAI-Organization"] = options.organization
82 }
83
84 if options.project != "" {
85 options.headers["OpenAI-Project"] = options.project
86 }
87
88 return &openAIProvider{
89 options: options,
90 }
91}
92
// WithOpenAIBaseURL overrides the default API base URL.
func WithOpenAIBaseURL(baseURL string) OpenAIOption {
	return func(o *openAIProviderOptions) {
		o.baseURL = baseURL
	}
}
98
// WithOpenAIApiKey sets the API key used to authenticate requests.
func WithOpenAIApiKey(apiKey string) OpenAIOption {
	return func(o *openAIProviderOptions) {
		o.apiKey = apiKey
	}
}
104
// WithOpenAIOrganization sets the organization sent via the
// OpenAI-Organization header.
func WithOpenAIOrganization(organization string) OpenAIOption {
	return func(o *openAIProviderOptions) {
		o.organization = organization
	}
}
110
// WithOpenAIProject sets the project sent via the OpenAI-Project header.
func WithOpenAIProject(project string) OpenAIOption {
	return func(o *openAIProviderOptions) {
		o.project = project
	}
}
116
// WithOpenAIName overrides the provider name used in model identifiers
// (e.g. "<name>.chat").
func WithOpenAIName(name string) OpenAIOption {
	return func(o *openAIProviderOptions) {
		o.name = name
	}
}
122
// WithOpenAIHeaders merges the given headers into every request,
// overwriting any existing entries with the same key.
func WithOpenAIHeaders(headers map[string]string) OpenAIOption {
	return func(o *openAIProviderOptions) {
		maps.Copy(o.headers, headers)
	}
}
128
// WithOpenAIHttpClient supplies a custom HTTP client for API requests.
func WithOpenAIHttpClient(client option.HTTPClient) OpenAIOption {
	return func(o *openAIProviderOptions) {
		o.client = client
	}
}
134
135// LanguageModel implements ai.Provider.
136func (o *openAIProvider) LanguageModel(modelID string) ai.LanguageModel {
137 openaiClientOptions := []option.RequestOption{}
138 if o.options.apiKey != "" {
139 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
140 }
141 if o.options.baseURL != "" {
142 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
143 }
144
145 for key, value := range o.options.headers {
146 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
147 }
148
149 if o.options.client != nil {
150 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
151 }
152
153 return openAILanguageModel{
154 modelID: modelID,
155 provider: fmt.Sprintf("%s.chat", o.options.name),
156 providerOptions: o.options,
157 client: openai.NewClient(openaiClientOptions...),
158 }
159}
160
// openAILanguageModel is an ai.LanguageModel backed by the OpenAI
// chat-completions API.
type openAILanguageModel struct {
	provider        string // e.g. "openai.chat"
	modelID         string // OpenAI model identifier
	client          openai.Client
	providerOptions openAIProviderOptions
}
167
// Model implements ai.LanguageModel. It returns the OpenAI model identifier.
func (o openAILanguageModel) Model() string {
	return o.modelID
}
172
// Provider implements ai.LanguageModel. It returns the provider identifier,
// e.g. "openai.chat".
func (o openAILanguageModel) Provider() string {
	return o.provider
}
177
178func (o openAILanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
179 params := &openai.ChatCompletionNewParams{}
180 messages, warnings := toOpenAIPrompt(call.Prompt)
181 providerOptions := &OpenAIProviderOptions{}
182 if v, ok := call.ProviderOptions["openai"]; ok {
183 err := ai.ParseOptions(v, providerOptions)
184 if err != nil {
185 return nil, nil, err
186 }
187 }
188 if call.TopK != nil {
189 warnings = append(warnings, ai.CallWarning{
190 Type: ai.CallWarningTypeUnsupportedSetting,
191 Setting: "top_k",
192 })
193 }
194 params.Messages = messages
195 params.Model = o.modelID
196 if providerOptions.LogitBias != nil {
197 params.LogitBias = providerOptions.LogitBias
198 }
199 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
200 providerOptions.LogProbs = nil
201 }
202 if providerOptions.LogProbs != nil {
203 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
204 }
205 if providerOptions.TopLogProbs != nil {
206 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
207 }
208 if providerOptions.User != nil {
209 params.User = param.NewOpt(*providerOptions.User)
210 }
211 if providerOptions.ParallelToolCalls != nil {
212 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
213 }
214
215 if call.MaxOutputTokens != nil {
216 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
217 }
218 if call.Temperature != nil {
219 params.Temperature = param.NewOpt(*call.Temperature)
220 }
221 if call.TopP != nil {
222 params.TopP = param.NewOpt(*call.TopP)
223 }
224 if call.FrequencyPenalty != nil {
225 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
226 }
227 if call.PresencePenalty != nil {
228 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
229 }
230
231 if providerOptions.MaxCompletionTokens != nil {
232 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
233 }
234
235 if providerOptions.TextVerbosity != nil {
236 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
237 }
238 if providerOptions.Prediction != nil {
239 // Convert map[string]any to ChatCompletionPredictionContentParam
240 if content, ok := providerOptions.Prediction["content"]; ok {
241 if contentStr, ok := content.(string); ok {
242 params.Prediction = openai.ChatCompletionPredictionContentParam{
243 Content: openai.ChatCompletionPredictionContentContentUnionParam{
244 OfString: param.NewOpt(contentStr),
245 },
246 }
247 }
248 }
249 }
250 if providerOptions.Store != nil {
251 params.Store = param.NewOpt(*providerOptions.Store)
252 }
253 if providerOptions.Metadata != nil {
254 // Convert map[string]any to map[string]string
255 metadata := make(map[string]string)
256 for k, v := range providerOptions.Metadata {
257 if str, ok := v.(string); ok {
258 metadata[k] = str
259 }
260 }
261 params.Metadata = metadata
262 }
263 if providerOptions.PromptCacheKey != nil {
264 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
265 }
266 if providerOptions.SafetyIdentifier != nil {
267 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
268 }
269 if providerOptions.ServiceTier != nil {
270 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
271 }
272
273 if providerOptions.ReasoningEffort != nil {
274 switch *providerOptions.ReasoningEffort {
275 case ReasoningEffortMinimal:
276 params.ReasoningEffort = shared.ReasoningEffortMinimal
277 case ReasoningEffortLow:
278 params.ReasoningEffort = shared.ReasoningEffortLow
279 case ReasoningEffortMedium:
280 params.ReasoningEffort = shared.ReasoningEffortMedium
281 case ReasoningEffortHigh:
282 params.ReasoningEffort = shared.ReasoningEffortHigh
283 default:
284 return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
285 }
286 }
287
288 if isReasoningModel(o.modelID) {
289 // remove unsupported settings for reasoning models
290 // see https://platform.openai.com/docs/guides/reasoning#limitations
291 if call.Temperature != nil {
292 params.Temperature = param.Opt[float64]{}
293 warnings = append(warnings, ai.CallWarning{
294 Type: ai.CallWarningTypeUnsupportedSetting,
295 Setting: "temperature",
296 Details: "temperature is not supported for reasoning models",
297 })
298 }
299 if call.TopP != nil {
300 params.TopP = param.Opt[float64]{}
301 warnings = append(warnings, ai.CallWarning{
302 Type: ai.CallWarningTypeUnsupportedSetting,
303 Setting: "top_p",
304 Details: "topP is not supported for reasoning models",
305 })
306 }
307 if call.FrequencyPenalty != nil {
308 params.FrequencyPenalty = param.Opt[float64]{}
309 warnings = append(warnings, ai.CallWarning{
310 Type: ai.CallWarningTypeUnsupportedSetting,
311 Setting: "frequency_penalty",
312 Details: "frequencyPenalty is not supported for reasoning models",
313 })
314 }
315 if call.PresencePenalty != nil {
316 params.PresencePenalty = param.Opt[float64]{}
317 warnings = append(warnings, ai.CallWarning{
318 Type: ai.CallWarningTypeUnsupportedSetting,
319 Setting: "presence_penalty",
320 Details: "presencePenalty is not supported for reasoning models",
321 })
322 }
323 if providerOptions.LogitBias != nil {
324 params.LogitBias = nil
325 warnings = append(warnings, ai.CallWarning{
326 Type: ai.CallWarningTypeOther,
327 Message: "logitBias is not supported for reasoning models",
328 })
329 }
330 if providerOptions.LogProbs != nil {
331 params.Logprobs = param.Opt[bool]{}
332 warnings = append(warnings, ai.CallWarning{
333 Type: ai.CallWarningTypeOther,
334 Message: "logprobs is not supported for reasoning models",
335 })
336 }
337 if providerOptions.TopLogProbs != nil {
338 params.TopLogprobs = param.Opt[int64]{}
339 warnings = append(warnings, ai.CallWarning{
340 Type: ai.CallWarningTypeOther,
341 Message: "topLogprobs is not supported for reasoning models",
342 })
343 }
344
345 // reasoning models use max_completion_tokens instead of max_tokens
346 if call.MaxOutputTokens != nil {
347 if providerOptions.MaxCompletionTokens == nil {
348 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
349 }
350 params.MaxTokens = param.Opt[int64]{}
351 }
352 }
353
354 // Handle search preview models
355 if isSearchPreviewModel(o.modelID) {
356 if call.Temperature != nil {
357 params.Temperature = param.Opt[float64]{}
358 warnings = append(warnings, ai.CallWarning{
359 Type: ai.CallWarningTypeUnsupportedSetting,
360 Setting: "temperature",
361 Details: "temperature is not supported for the search preview models and has been removed.",
362 })
363 }
364 }
365
366 // Handle service tier validation
367 if providerOptions.ServiceTier != nil {
368 serviceTier := *providerOptions.ServiceTier
369 if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
370 params.ServiceTier = ""
371 warnings = append(warnings, ai.CallWarning{
372 Type: ai.CallWarningTypeUnsupportedSetting,
373 Setting: "serviceTier",
374 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
375 })
376 } else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
377 params.ServiceTier = ""
378 warnings = append(warnings, ai.CallWarning{
379 Type: ai.CallWarningTypeUnsupportedSetting,
380 Setting: "serviceTier",
381 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
382 })
383 }
384 }
385
386 if len(call.Tools) > 0 {
387 tools, toolChoice, toolWarnings := toOpenAITools(call.Tools, call.ToolChoice)
388 params.Tools = tools
389 if toolChoice != nil {
390 params.ToolChoice = *toolChoice
391 }
392 warnings = append(warnings, toolWarnings...)
393 }
394 return params, warnings, nil
395}
396
397func (o openAILanguageModel) handleError(err error) error {
398 var apiErr *openai.Error
399 if errors.As(err, &apiErr) {
400 requestDump := apiErr.DumpRequest(true)
401 responseDump := apiErr.DumpResponse(true)
402 headers := map[string]string{}
403 for k, h := range apiErr.Response.Header {
404 v := h[len(h)-1]
405 headers[strings.ToLower(k)] = v
406 }
407 return ai.NewAPICallError(
408 apiErr.Message,
409 apiErr.Request.URL.String(),
410 string(requestDump),
411 apiErr.StatusCode,
412 headers,
413 string(responseDump),
414 apiErr,
415 false,
416 )
417 }
418 return err
419}
420
// Generate implements ai.LanguageModel. It performs a single (non-streaming)
// chat completion, mapping the first choice's text, tool calls, and URL
// citations into ai.Content, and exposing logprobs/prediction-token counts
// through provider metadata.
func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	// Only the first choice is consumed; additional choices are ignored.
	choice := response.Choices[0]
	var content []ai.Content
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		// Some responses omit the tool-call ID; synthesize one so callers can
		// always correlate tool results.
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAIFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}
505
// toolCall accumulates the state of one streamed tool call, keyed by the
// tool-call index in the stream.
type toolCall struct {
	id          string
	name        string
	arguments   string // concatenated JSON argument fragments
	hasFinished bool   // set once the complete call has been emitted
}
512
513// Stream implements ai.LanguageModel.
514func (o openAILanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
515 params, warnings, err := o.prepareParams(call)
516 if err != nil {
517 return nil, err
518 }
519
520 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
521 IncludeUsage: openai.Bool(true),
522 }
523
524 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
525 isActiveText := false
526 toolCalls := make(map[int64]toolCall)
527
528 // Build provider metadata for streaming
529 streamProviderMetadata := ai.ProviderOptions{
530 "openai": make(map[string]any),
531 }
532
533 acc := openai.ChatCompletionAccumulator{}
534 var usage ai.Usage
535 return func(yield func(ai.StreamPart) bool) {
536 if len(warnings) > 0 {
537 if !yield(ai.StreamPart{
538 Type: ai.StreamPartTypeWarnings,
539 Warnings: warnings,
540 }) {
541 return
542 }
543 }
544 for stream.Next() {
545 chunk := stream.Current()
546 acc.AddChunk(chunk)
547 if chunk.Usage.TotalTokens > 0 {
548 // we do this here because the acc does not add prompt details
549 completionTokenDetails := chunk.Usage.CompletionTokensDetails
550 promptTokenDetails := chunk.Usage.PromptTokensDetails
551 usage = ai.Usage{
552 InputTokens: chunk.Usage.PromptTokens,
553 OutputTokens: chunk.Usage.CompletionTokens,
554 TotalTokens: chunk.Usage.TotalTokens,
555 ReasoningTokens: completionTokenDetails.ReasoningTokens,
556 CacheReadTokens: promptTokenDetails.CachedTokens,
557 }
558
559 // Add prediction tokens if available
560 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
561 if completionTokenDetails.AcceptedPredictionTokens > 0 {
562 streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
563 }
564 if completionTokenDetails.RejectedPredictionTokens > 0 {
565 streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
566 }
567 }
568 }
569 if len(chunk.Choices) == 0 {
570 continue
571 }
572 for _, choice := range chunk.Choices {
573 switch {
574 case choice.Delta.Content != "":
575 if !isActiveText {
576 isActiveText = true
577 if !yield(ai.StreamPart{
578 Type: ai.StreamPartTypeTextStart,
579 ID: "0",
580 }) {
581 return
582 }
583 }
584 if !yield(ai.StreamPart{
585 Type: ai.StreamPartTypeTextDelta,
586 ID: "0",
587 Delta: choice.Delta.Content,
588 }) {
589 return
590 }
591 case len(choice.Delta.ToolCalls) > 0:
592 if isActiveText {
593 isActiveText = false
594 if !yield(ai.StreamPart{
595 Type: ai.StreamPartTypeTextEnd,
596 ID: "0",
597 }) {
598 return
599 }
600 }
601
602 for _, toolCallDelta := range choice.Delta.ToolCalls {
603 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
604 if existingToolCall.hasFinished {
605 continue
606 }
607 if toolCallDelta.Function.Arguments != "" {
608 existingToolCall.arguments += toolCallDelta.Function.Arguments
609 }
610 if !yield(ai.StreamPart{
611 Type: ai.StreamPartTypeToolInputDelta,
612 ID: existingToolCall.id,
613 Delta: toolCallDelta.Function.Arguments,
614 }) {
615 return
616 }
617 toolCalls[toolCallDelta.Index] = existingToolCall
618 if existingToolCall.arguments != "" && ai.IsParsableJSON(existingToolCall.arguments) {
619 if !yield(ai.StreamPart{
620 Type: ai.StreamPartTypeToolInputEnd,
621 ID: existingToolCall.id,
622 }) {
623 return
624 }
625
626 if !yield(ai.StreamPart{
627 Type: ai.StreamPartTypeToolCall,
628 ID: existingToolCall.id,
629 ToolCallName: existingToolCall.name,
630 ToolCallInput: existingToolCall.arguments,
631 }) {
632 return
633 }
634 existingToolCall.hasFinished = true
635 toolCalls[toolCallDelta.Index] = existingToolCall
636 }
637 } else {
638 // Does not exist
639 var err error
640 if toolCallDelta.Type != "function" {
641 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
642 }
643 if toolCallDelta.ID == "" {
644 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
645 }
646 if toolCallDelta.Function.Name == "" {
647 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
648 }
649 if err != nil {
650 yield(ai.StreamPart{
651 Type: ai.StreamPartTypeError,
652 Error: o.handleError(stream.Err()),
653 })
654 return
655 }
656
657 if !yield(ai.StreamPart{
658 Type: ai.StreamPartTypeToolInputStart,
659 ID: toolCallDelta.ID,
660 ToolCallName: toolCallDelta.Function.Name,
661 }) {
662 return
663 }
664 toolCalls[toolCallDelta.Index] = toolCall{
665 id: toolCallDelta.ID,
666 name: toolCallDelta.Function.Name,
667 arguments: toolCallDelta.Function.Arguments,
668 }
669
670 exTc := toolCalls[toolCallDelta.Index]
671 if exTc.arguments != "" {
672 if !yield(ai.StreamPart{
673 Type: ai.StreamPartTypeToolInputDelta,
674 ID: exTc.id,
675 Delta: exTc.arguments,
676 }) {
677 return
678 }
679 if ai.IsParsableJSON(toolCalls[toolCallDelta.Index].arguments) {
680 if !yield(ai.StreamPart{
681 Type: ai.StreamPartTypeToolInputEnd,
682 ID: toolCallDelta.ID,
683 }) {
684 return
685 }
686
687 if !yield(ai.StreamPart{
688 Type: ai.StreamPartTypeToolCall,
689 ID: exTc.id,
690 ToolCallName: exTc.name,
691 ToolCallInput: exTc.arguments,
692 }) {
693 return
694 }
695 exTc.hasFinished = true
696 toolCalls[toolCallDelta.Index] = exTc
697 }
698 }
699 continue
700 }
701 }
702 }
703 }
704
705 // Check for annotations in the delta's raw JSON
706 for _, choice := range chunk.Choices {
707 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
708 for _, annotation := range annotations {
709 if annotation.Type == "url_citation" {
710 if !yield(ai.StreamPart{
711 Type: ai.StreamPartTypeSource,
712 ID: uuid.NewString(),
713 SourceType: ai.SourceTypeURL,
714 URL: annotation.URLCitation.URL,
715 Title: annotation.URLCitation.Title,
716 }) {
717 return
718 }
719 }
720 }
721 }
722 }
723 }
724 err := stream.Err()
725 if err == nil || errors.Is(err, io.EOF) {
726 // finished
727 if isActiveText {
728 isActiveText = false
729 if !yield(ai.StreamPart{
730 Type: ai.StreamPartTypeTextEnd,
731 ID: "0",
732 }) {
733 return
734 }
735 }
736
737 // Add logprobs if available
738 if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
739 streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
740 }
741
742 // Handle annotations/citations from accumulated response
743 if len(acc.Choices) > 0 {
744 for _, annotation := range acc.Choices[0].Message.Annotations {
745 if annotation.Type == "url_citation" {
746 if !yield(ai.StreamPart{
747 Type: ai.StreamPartTypeSource,
748 ID: uuid.NewString(),
749 SourceType: ai.SourceTypeURL,
750 URL: annotation.URLCitation.URL,
751 Title: annotation.URLCitation.Title,
752 }) {
753 return
754 }
755 }
756 }
757 }
758
759 finishReason := mapOpenAIFinishReason(acc.Choices[0].FinishReason)
760 yield(ai.StreamPart{
761 Type: ai.StreamPartTypeFinish,
762 Usage: usage,
763 FinishReason: finishReason,
764 ProviderMetadata: streamProviderMetadata,
765 })
766 return
767 } else {
768 yield(ai.StreamPart{
769 Type: ai.StreamPartTypeError,
770 Error: stream.Err(),
771 })
772 return
773 }
774 }, nil
775}
776
777func mapOpenAIFinishReason(finishReason string) ai.FinishReason {
778 switch finishReason {
779 case "stop":
780 return ai.FinishReasonStop
781 case "length":
782 return ai.FinishReasonLength
783 case "content_filter":
784 return ai.FinishReasonContentFilter
785 case "function_call", "tool_calls":
786 return ai.FinishReasonToolCalls
787 default:
788 return ai.FinishReasonUnknown
789 }
790}
791
// isReasoningModel reports whether modelID names an OpenAI reasoning model
// (o-series or gpt-5 family), which accept a restricted parameter set.
// gpt-5-chat models are conventional chat models, not reasoning models: the
// previous `|| HasPrefix(modelID, "gpt-5-chat")` clause was dead code (already
// covered by the "gpt-5" prefix), indicating an exclusion was intended.
func isReasoningModel(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-chat") {
		return false
	}
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
795
// isSearchPreviewModel reports whether modelID is a search-preview variant,
// which rejects the temperature setting.
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
799
// supportsFlexProcessing reports whether modelID is eligible for the "flex"
// service tier (o3, o4-mini, and gpt-5 families).
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
803
// supportsPriorityProcessing reports whether modelID is eligible for the
// "priority" service tier (gpt-4, gpt-5 except gpt-5-nano, o3, o4-mini).
// The warning emitted in prepareParams states that gpt-5-nano is not
// supported, but the previous "gpt-5" prefix check matched it anyway — it is
// now explicitly excluded. The redundant "gpt-5-mini" check (subsumed by the
// "gpt-5" prefix) is dropped.
func supportsPriorityProcessing(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-nano") {
		return false
	}
	return strings.HasPrefix(modelID, "gpt-4") ||
		strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}
809
810func toOpenAITools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAITools []openai.ChatCompletionToolUnionParam, openAIToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
811 for _, tool := range tools {
812 if tool.GetType() == ai.ToolTypeFunction {
813 ft, ok := tool.(ai.FunctionTool)
814 if !ok {
815 continue
816 }
817 openAITools = append(openAITools, openai.ChatCompletionToolUnionParam{
818 OfFunction: &openai.ChatCompletionFunctionToolParam{
819 Function: shared.FunctionDefinitionParam{
820 Name: ft.Name,
821 Description: param.NewOpt(ft.Description),
822 Parameters: openai.FunctionParameters(ft.InputSchema),
823 Strict: param.NewOpt(false),
824 },
825 Type: "function",
826 },
827 })
828 continue
829 }
830
831 // TODO: handle provider tool calls
832 warnings = append(warnings, ai.CallWarning{
833 Type: ai.CallWarningTypeUnsupportedTool,
834 Tool: tool,
835 Message: "tool is not supported",
836 })
837 }
838 if toolChoice == nil {
839 return
840 }
841
842 switch *toolChoice {
843 case ai.ToolChoiceAuto:
844 openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
845 OfAuto: param.NewOpt("auto"),
846 }
847 case ai.ToolChoiceNone:
848 openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
849 OfAuto: param.NewOpt("none"),
850 }
851 default:
852 openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
853 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
854 Type: "function",
855 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
856 Name: string(*toolChoice),
857 },
858 },
859 }
860 }
861 return
862}
863
864func toOpenAIPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
865 var messages []openai.ChatCompletionMessageParamUnion
866 var warnings []ai.CallWarning
867 for _, msg := range prompt {
868 switch msg.Role {
869 case ai.MessageRoleSystem:
870 var systemPromptParts []string
871 for _, c := range msg.Content {
872 if c.GetType() != ai.ContentTypeText {
873 warnings = append(warnings, ai.CallWarning{
874 Type: ai.CallWarningTypeOther,
875 Message: "system prompt can only have text content",
876 })
877 continue
878 }
879 textPart, ok := ai.AsContentType[ai.TextPart](c)
880 if !ok {
881 warnings = append(warnings, ai.CallWarning{
882 Type: ai.CallWarningTypeOther,
883 Message: "system prompt text part does not have the right type",
884 })
885 continue
886 }
887 text := textPart.Text
888 if strings.TrimSpace(text) != "" {
889 systemPromptParts = append(systemPromptParts, textPart.Text)
890 }
891 }
892 if len(systemPromptParts) == 0 {
893 warnings = append(warnings, ai.CallWarning{
894 Type: ai.CallWarningTypeOther,
895 Message: "system prompt has no text parts",
896 })
897 continue
898 }
899 messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
900 case ai.MessageRoleUser:
901 // simple user message just text content
902 if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
903 textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
904 if !ok {
905 warnings = append(warnings, ai.CallWarning{
906 Type: ai.CallWarningTypeOther,
907 Message: "user message text part does not have the right type",
908 })
909 continue
910 }
911 messages = append(messages, openai.UserMessage(textPart.Text))
912 continue
913 }
914 // text content and attachments
915 // for now we only support image content later we need to check
916 // TODO: add the supported media types to the language model so we
917 // can use that to validate the data here.
918 var content []openai.ChatCompletionContentPartUnionParam
919 for _, c := range msg.Content {
920 switch c.GetType() {
921 case ai.ContentTypeText:
922 textPart, ok := ai.AsContentType[ai.TextPart](c)
923 if !ok {
924 warnings = append(warnings, ai.CallWarning{
925 Type: ai.CallWarningTypeOther,
926 Message: "user message text part does not have the right type",
927 })
928 continue
929 }
930 content = append(content, openai.ChatCompletionContentPartUnionParam{
931 OfText: &openai.ChatCompletionContentPartTextParam{
932 Text: textPart.Text,
933 },
934 })
935 case ai.ContentTypeFile:
936 filePart, ok := ai.AsContentType[ai.FilePart](c)
937 if !ok {
938 warnings = append(warnings, ai.CallWarning{
939 Type: ai.CallWarningTypeOther,
940 Message: "user message file part does not have the right type",
941 })
942 continue
943 }
944
945 switch {
946 case strings.HasPrefix(filePart.MediaType, "image/"):
947 // Handle image files
948 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
949 data := "data:" + filePart.MediaType + ";base64," + base64Encoded
950 imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
951
952 // Check for provider-specific options like image detail
953 if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
954 if detail, ok := providerOptions["imageDetail"].(string); ok {
955 imageURL.Detail = detail
956 }
957 }
958
959 imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
960 content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
961
962 case filePart.MediaType == "audio/wav":
963 // Handle WAV audio files
964 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
965 audioBlock := openai.ChatCompletionContentPartInputAudioParam{
966 InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
967 Data: base64Encoded,
968 Format: "wav",
969 },
970 }
971 content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
972
973 case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
974 // Handle MP3 audio files
975 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
976 audioBlock := openai.ChatCompletionContentPartInputAudioParam{
977 InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
978 Data: base64Encoded,
979 Format: "mp3",
980 },
981 }
982 content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
983
984 case filePart.MediaType == "application/pdf":
985 // Handle PDF files
986 dataStr := string(filePart.Data)
987
988 // Check if data looks like a file ID (starts with "file-")
989 if strings.HasPrefix(dataStr, "file-") {
990 fileBlock := openai.ChatCompletionContentPartFileParam{
991 File: openai.ChatCompletionContentPartFileFileParam{
992 FileID: param.NewOpt(dataStr),
993 },
994 }
995 content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
996 } else {
997 // Handle as base64 data
998 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
999 data := "data:application/pdf;base64," + base64Encoded
1000
1001 filename := filePart.Filename
1002 if filename == "" {
1003 // Generate default filename based on content index
1004 filename = fmt.Sprintf("part-%d.pdf", len(content))
1005 }
1006
1007 fileBlock := openai.ChatCompletionContentPartFileParam{
1008 File: openai.ChatCompletionContentPartFileFileParam{
1009 Filename: param.NewOpt(filename),
1010 FileData: param.NewOpt(data),
1011 },
1012 }
1013 content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1014 }
1015
1016 default:
1017 warnings = append(warnings, ai.CallWarning{
1018 Type: ai.CallWarningTypeOther,
1019 Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
1020 })
1021 }
1022 }
1023 }
1024 messages = append(messages, openai.UserMessage(content))
1025 case ai.MessageRoleAssistant:
1026 // simple assistant message just text content
1027 if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
1028 textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
1029 if !ok {
1030 warnings = append(warnings, ai.CallWarning{
1031 Type: ai.CallWarningTypeOther,
1032 Message: "assistant message text part does not have the right type",
1033 })
1034 continue
1035 }
1036 messages = append(messages, openai.AssistantMessage(textPart.Text))
1037 continue
1038 }
1039 assistantMsg := openai.ChatCompletionAssistantMessageParam{
1040 Role: "assistant",
1041 }
1042 for _, c := range msg.Content {
1043 switch c.GetType() {
1044 case ai.ContentTypeText:
1045 textPart, ok := ai.AsContentType[ai.TextPart](c)
1046 if !ok {
1047 warnings = append(warnings, ai.CallWarning{
1048 Type: ai.CallWarningTypeOther,
1049 Message: "assistant message text part does not have the right type",
1050 })
1051 continue
1052 }
1053 assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
1054 OfString: param.NewOpt(textPart.Text),
1055 }
1056 case ai.ContentTypeToolCall:
1057 toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
1058 if !ok {
1059 warnings = append(warnings, ai.CallWarning{
1060 Type: ai.CallWarningTypeOther,
1061 Message: "assistant message tool part does not have the right type",
1062 })
1063 continue
1064 }
1065 assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
1066 openai.ChatCompletionMessageToolCallUnionParam{
1067 OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
1068 ID: toolCallPart.ToolCallID,
1069 Type: "function",
1070 Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
1071 Name: toolCallPart.ToolName,
1072 Arguments: toolCallPart.Input,
1073 },
1074 },
1075 })
1076 }
1077 }
1078 messages = append(messages, openai.ChatCompletionMessageParamUnion{
1079 OfAssistant: &assistantMsg,
1080 })
1081 case ai.MessageRoleTool:
1082 for _, c := range msg.Content {
1083 if c.GetType() != ai.ContentTypeToolResult {
1084 warnings = append(warnings, ai.CallWarning{
1085 Type: ai.CallWarningTypeOther,
1086 Message: "tool message can only have tool result content",
1087 })
1088 continue
1089 }
1090
1091 toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
1092 if !ok {
1093 warnings = append(warnings, ai.CallWarning{
1094 Type: ai.CallWarningTypeOther,
1095 Message: "tool message result part does not have the right type",
1096 })
1097 continue
1098 }
1099
1100 switch toolResultPart.Output.GetType() {
1101 case ai.ToolResultContentTypeText:
1102 output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
1103 if !ok {
1104 warnings = append(warnings, ai.CallWarning{
1105 Type: ai.CallWarningTypeOther,
1106 Message: "tool result output does not have the right type",
1107 })
1108 continue
1109 }
1110 messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
1111 case ai.ToolResultContentTypeError:
1112 // TODO: check if better handling is needed
1113 output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
1114 if !ok {
1115 warnings = append(warnings, ai.CallWarning{
1116 Type: ai.CallWarningTypeOther,
1117 Message: "tool result output does not have the right type",
1118 })
1119 continue
1120 }
1121 messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
1122 }
1123 }
1124 }
1125 }
1126 return messages, warnings
1127}
1128
1129// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta
1130func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
1131 var annotations []openai.ChatCompletionMessageAnnotation
1132
1133 // Parse the raw JSON to extract annotations
1134 var deltaData map[string]any
1135 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
1136 return annotations
1137 }
1138
1139 // Check if annotations exist in the delta
1140 if annotationsData, ok := deltaData["annotations"].([]any); ok {
1141 for _, annotationData := range annotationsData {
1142 if annotationMap, ok := annotationData.(map[string]any); ok {
1143 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
1144 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
1145 annotation := openai.ChatCompletionMessageAnnotation{
1146 Type: "url_citation",
1147 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
1148 URL: urlCitationData["url"].(string),
1149 Title: urlCitationData["title"].(string),
1150 },
1151 }
1152 annotations = append(annotations, annotation)
1153 }
1154 }
1155 }
1156 }
1157 }
1158
1159 return annotations
1160}