1package providers
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "maps"
11 "strings"
12
13 "github.com/charmbracelet/crush/internal/ai"
14 "github.com/google/uuid"
15 "github.com/openai/openai-go/v2"
16 "github.com/openai/openai-go/v2/option"
17 "github.com/openai/openai-go/v2/packages/param"
18 "github.com/openai/openai-go/v2/shared"
19)
20
// ReasoningEffort selects how much internal reasoning an OpenAI reasoning
// model should spend before answering. It is forwarded to the API's
// reasoning_effort parameter.
type ReasoningEffort string

const (
	// ReasoningEffortMinimal requests the least reasoning effort.
	ReasoningEffortMinimal ReasoningEffort = "minimal"
	// ReasoningEffortLow requests low reasoning effort.
	ReasoningEffortLow ReasoningEffort = "low"
	// ReasoningEffortMedium requests medium reasoning effort.
	ReasoningEffortMedium ReasoningEffort = "medium"
	// ReasoningEffortHigh requests high reasoning effort.
	ReasoningEffortHigh ReasoningEffort = "high"
)
29
30type OpenAIProviderOptions struct {
31 LogitBias map[string]int64 `json:"logit_bias"`
32 LogProbs *bool `json:"log_probes"`
33 TopLogProbs *int64 `json:"top_log_probs"`
34 ParallelToolCalls *bool `json:"parallel_tool_calls"`
35 User *string `json:"user"`
36 ReasoningEffort *ReasoningEffort `json:"reasoning_effort"`
37 MaxCompletionTokens *int64 `json:"max_completion_tokens"`
38 TextVerbosity *string `json:"text_verbosity"`
39 Prediction map[string]any `json:"prediction"`
40 Store *bool `json:"store"`
41 Metadata map[string]any `json:"metadata"`
42 PromptCacheKey *string `json:"prompt_cache_key"`
43 SafetyIdentifier *string `json:"safety_identifier"`
44 ServiceTier *string `json:"service_tier"`
45 StructuredOutputs *bool `json:"structured_outputs"`
46}
47
// openAIProvider is the ai.Provider implementation backed by the OpenAI
// (or an OpenAI-compatible) chat completions API.
type openAIProvider struct {
	options openAIProviderOptions
}
51
// openAIProviderOptions holds the provider-level configuration collected by
// the WithOpenAI* functional options.
type openAIProviderOptions struct {
	baseURL      string            // API base URL; defaults to https://api.openai.com/v1
	apiKey       string            // bearer token used to authenticate requests
	organization string            // sent as the OpenAI-Organization header when set
	project      string            // sent as the OpenAI-Project header when set
	name         string            // provider name used in model identifiers; defaults to "openai"
	headers      map[string]string // extra headers applied to every request
	client       option.HTTPClient // optional custom HTTP client
}
61
// OpenAIOption configures a provider created by NewOpenAIProvider.
type OpenAIOption = func(*openAIProviderOptions)
63
64func NewOpenAIProvider(opts ...OpenAIOption) ai.Provider {
65 options := openAIProviderOptions{
66 headers: map[string]string{},
67 }
68 for _, o := range opts {
69 o(&options)
70 }
71
72 if options.baseURL == "" {
73 options.baseURL = "https://api.openai.com/v1"
74 }
75
76 if options.name == "" {
77 options.name = "openai"
78 }
79
80 if options.organization != "" {
81 options.headers["OpenAI-Organization"] = options.organization
82 }
83
84 if options.project != "" {
85 options.headers["OpenAI-Project"] = options.project
86 }
87
88 return &openAIProvider{
89 options: options,
90 }
91}
92
93func WithOpenAIBaseURL(baseURL string) OpenAIOption {
94 return func(o *openAIProviderOptions) {
95 o.baseURL = baseURL
96 }
97}
98
99func WithOpenAIApiKey(apiKey string) OpenAIOption {
100 return func(o *openAIProviderOptions) {
101 o.apiKey = apiKey
102 }
103}
104
105func WithOpenAIOrganization(organization string) OpenAIOption {
106 return func(o *openAIProviderOptions) {
107 o.organization = organization
108 }
109}
110
111func WithOpenAIProject(project string) OpenAIOption {
112 return func(o *openAIProviderOptions) {
113 o.project = project
114 }
115}
116
117func WithOpenAIName(name string) OpenAIOption {
118 return func(o *openAIProviderOptions) {
119 o.name = name
120 }
121}
122
123func WithOpenAIHeaders(headers map[string]string) OpenAIOption {
124 return func(o *openAIProviderOptions) {
125 maps.Copy(o.headers, headers)
126 }
127}
128
129func WithOpenAIHttpClient(client option.HTTPClient) OpenAIOption {
130 return func(o *openAIProviderOptions) {
131 o.client = client
132 }
133}
134
135// LanguageModel implements ai.Provider.
136func (o *openAIProvider) LanguageModel(modelID string) (ai.LanguageModel, error) {
137 openaiClientOptions := []option.RequestOption{}
138 if o.options.apiKey != "" {
139 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
140 }
141 if o.options.baseURL != "" {
142 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
143 }
144
145 for key, value := range o.options.headers {
146 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
147 }
148
149 if o.options.client != nil {
150 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
151 }
152
153 return openAILanguageModel{
154 modelID: modelID,
155 provider: fmt.Sprintf("%s.chat", o.options.name),
156 providerOptions: o.options,
157 client: openai.NewClient(openaiClientOptions...),
158 }, nil
159}
160
// openAILanguageModel is the ai.LanguageModel implementation for a single
// OpenAI chat model.
type openAILanguageModel struct {
	provider        string                // provider identifier, e.g. "openai.chat"
	modelID         string                // model name sent to the API, e.g. "gpt-4o"
	client          openai.Client         // configured OpenAI SDK client
	providerOptions openAIProviderOptions // provider-level settings used to build the client
}
167
// Model implements ai.LanguageModel. It returns the OpenAI model identifier.
func (o openAILanguageModel) Model() string {
	return o.modelID
}
172
// Provider implements ai.LanguageModel. It returns the provider identifier
// (e.g. "openai.chat").
func (o openAILanguageModel) Provider() string {
	return o.provider
}
177
178func (o openAILanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
179 params := &openai.ChatCompletionNewParams{}
180 messages, warnings := toOpenAIPrompt(call.Prompt)
181 providerOptions := &OpenAIProviderOptions{}
182 if v, ok := call.ProviderOptions["openai"]; ok {
183 err := ai.ParseOptions(v, providerOptions)
184 if err != nil {
185 return nil, nil, err
186 }
187 }
188 if call.TopK != nil {
189 warnings = append(warnings, ai.CallWarning{
190 Type: ai.CallWarningTypeUnsupportedSetting,
191 Setting: "top_k",
192 })
193 }
194 params.Messages = messages
195 params.Model = o.modelID
196 if providerOptions.LogitBias != nil {
197 params.LogitBias = providerOptions.LogitBias
198 }
199 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
200 providerOptions.LogProbs = nil
201 }
202 if providerOptions.LogProbs != nil {
203 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
204 }
205 if providerOptions.TopLogProbs != nil {
206 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
207 }
208 if providerOptions.User != nil {
209 params.User = param.NewOpt(*providerOptions.User)
210 }
211 if providerOptions.ParallelToolCalls != nil {
212 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
213 }
214
215 if call.MaxOutputTokens != nil {
216 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
217 }
218 if call.Temperature != nil {
219 params.Temperature = param.NewOpt(*call.Temperature)
220 }
221 if call.TopP != nil {
222 params.TopP = param.NewOpt(*call.TopP)
223 }
224 if call.FrequencyPenalty != nil {
225 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
226 }
227 if call.PresencePenalty != nil {
228 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
229 }
230
231 if providerOptions.MaxCompletionTokens != nil {
232 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
233 }
234
235 if providerOptions.TextVerbosity != nil {
236 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
237 }
238 if providerOptions.Prediction != nil {
239 // Convert map[string]any to ChatCompletionPredictionContentParam
240 if content, ok := providerOptions.Prediction["content"]; ok {
241 if contentStr, ok := content.(string); ok {
242 params.Prediction = openai.ChatCompletionPredictionContentParam{
243 Content: openai.ChatCompletionPredictionContentContentUnionParam{
244 OfString: param.NewOpt(contentStr),
245 },
246 }
247 }
248 }
249 }
250 if providerOptions.Store != nil {
251 params.Store = param.NewOpt(*providerOptions.Store)
252 }
253 if providerOptions.Metadata != nil {
254 // Convert map[string]any to map[string]string
255 metadata := make(map[string]string)
256 for k, v := range providerOptions.Metadata {
257 if str, ok := v.(string); ok {
258 metadata[k] = str
259 }
260 }
261 params.Metadata = metadata
262 }
263 if providerOptions.PromptCacheKey != nil {
264 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
265 }
266 if providerOptions.SafetyIdentifier != nil {
267 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
268 }
269 if providerOptions.ServiceTier != nil {
270 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
271 }
272
273 if providerOptions.ReasoningEffort != nil {
274 switch *providerOptions.ReasoningEffort {
275 case ReasoningEffortMinimal:
276 params.ReasoningEffort = shared.ReasoningEffortMinimal
277 case ReasoningEffortLow:
278 params.ReasoningEffort = shared.ReasoningEffortLow
279 case ReasoningEffortMedium:
280 params.ReasoningEffort = shared.ReasoningEffortMedium
281 case ReasoningEffortHigh:
282 params.ReasoningEffort = shared.ReasoningEffortHigh
283 default:
284 return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
285 }
286 }
287
288 if isReasoningModel(o.modelID) {
289 // remove unsupported settings for reasoning models
290 // see https://platform.openai.com/docs/guides/reasoning#limitations
291 if call.Temperature != nil {
292 params.Temperature = param.Opt[float64]{}
293 warnings = append(warnings, ai.CallWarning{
294 Type: ai.CallWarningTypeUnsupportedSetting,
295 Setting: "temperature",
296 Details: "temperature is not supported for reasoning models",
297 })
298 }
299 if call.TopP != nil {
300 params.TopP = param.Opt[float64]{}
301 warnings = append(warnings, ai.CallWarning{
302 Type: ai.CallWarningTypeUnsupportedSetting,
303 Setting: "TopP",
304 Details: "TopP is not supported for reasoning models",
305 })
306 }
307 if call.FrequencyPenalty != nil {
308 params.FrequencyPenalty = param.Opt[float64]{}
309 warnings = append(warnings, ai.CallWarning{
310 Type: ai.CallWarningTypeUnsupportedSetting,
311 Setting: "FrequencyPenalty",
312 Details: "FrequencyPenalty is not supported for reasoning models",
313 })
314 }
315 if call.PresencePenalty != nil {
316 params.PresencePenalty = param.Opt[float64]{}
317 warnings = append(warnings, ai.CallWarning{
318 Type: ai.CallWarningTypeUnsupportedSetting,
319 Setting: "PresencePenalty",
320 Details: "PresencePenalty is not supported for reasoning models",
321 })
322 }
323 if providerOptions.LogitBias != nil {
324 params.LogitBias = nil
325 warnings = append(warnings, ai.CallWarning{
326 Type: ai.CallWarningTypeUnsupportedSetting,
327 Setting: "LogitBias",
328 Message: "LogitBias is not supported for reasoning models",
329 })
330 }
331 if providerOptions.LogProbs != nil {
332 params.Logprobs = param.Opt[bool]{}
333 warnings = append(warnings, ai.CallWarning{
334 Type: ai.CallWarningTypeUnsupportedSetting,
335 Setting: "Logprobs",
336 Message: "Logprobs is not supported for reasoning models",
337 })
338 }
339 if providerOptions.TopLogProbs != nil {
340 params.TopLogprobs = param.Opt[int64]{}
341 warnings = append(warnings, ai.CallWarning{
342 Type: ai.CallWarningTypeUnsupportedSetting,
343 Setting: "TopLogprobs",
344 Message: "TopLogprobs is not supported for reasoning models",
345 })
346 }
347
348 // reasoning models use max_completion_tokens instead of max_tokens
349 if call.MaxOutputTokens != nil {
350 if providerOptions.MaxCompletionTokens == nil {
351 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
352 }
353 params.MaxTokens = param.Opt[int64]{}
354 }
355 }
356
357 // Handle search preview models
358 if isSearchPreviewModel(o.modelID) {
359 if call.Temperature != nil {
360 params.Temperature = param.Opt[float64]{}
361 warnings = append(warnings, ai.CallWarning{
362 Type: ai.CallWarningTypeUnsupportedSetting,
363 Setting: "temperature",
364 Details: "temperature is not supported for the search preview models and has been removed.",
365 })
366 }
367 }
368
369 // Handle service tier validation
370 if providerOptions.ServiceTier != nil {
371 serviceTier := *providerOptions.ServiceTier
372 if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
373 params.ServiceTier = ""
374 warnings = append(warnings, ai.CallWarning{
375 Type: ai.CallWarningTypeUnsupportedSetting,
376 Setting: "ServiceTier",
377 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
378 })
379 } else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
380 params.ServiceTier = ""
381 warnings = append(warnings, ai.CallWarning{
382 Type: ai.CallWarningTypeUnsupportedSetting,
383 Setting: "ServiceTier",
384 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
385 })
386 }
387 }
388
389 if len(call.Tools) > 0 {
390 tools, toolChoice, toolWarnings := toOpenAITools(call.Tools, call.ToolChoice)
391 params.Tools = tools
392 if toolChoice != nil {
393 params.ToolChoice = *toolChoice
394 }
395 warnings = append(warnings, toolWarnings...)
396 }
397 return params, warnings, nil
398}
399
400func (o openAILanguageModel) handleError(err error) error {
401 var apiErr *openai.Error
402 if errors.As(err, &apiErr) {
403 requestDump := apiErr.DumpRequest(true)
404 responseDump := apiErr.DumpResponse(true)
405 headers := map[string]string{}
406 for k, h := range apiErr.Response.Header {
407 v := h[len(h)-1]
408 headers[strings.ToLower(k)] = v
409 }
410 return ai.NewAPICallError(
411 apiErr.Message,
412 apiErr.Request.URL.String(),
413 string(requestDump),
414 apiErr.StatusCode,
415 headers,
416 string(responseDump),
417 apiErr,
418 false,
419 )
420 }
421 return err
422}
423
// Generate implements ai.LanguageModel.
//
// It performs a single non-streaming chat completion and converts the first
// choice into ai.Content parts (text, tool calls, url_citation sources),
// attaching token usage and OpenAI-specific metadata (logprobs, predicted
// token counts) to the response.
func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	// Only the first choice is considered.
	choice := response.Choices[0]
	var content []ai.Content
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		// Some backends omit tool-call IDs; synthesize one so downstream
		// correlation between calls and results still works.
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAIFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}
508
// toolCall tracks the incremental state of one streamed tool call while its
// arguments arrive across multiple chunks.
type toolCall struct {
	id          string
	name        string
	arguments   string // JSON argument fragments concatenated so far
	hasFinished bool   // true once the completed call has been emitted
}
515
516// Stream implements ai.LanguageModel.
517func (o openAILanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
518 params, warnings, err := o.prepareParams(call)
519 if err != nil {
520 return nil, err
521 }
522
523 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
524 IncludeUsage: openai.Bool(true),
525 }
526
527 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
528 isActiveText := false
529 toolCalls := make(map[int64]toolCall)
530
531 // Build provider metadata for streaming
532 streamProviderMetadata := ai.ProviderMetadata{
533 "openai": make(map[string]any),
534 }
535
536 acc := openai.ChatCompletionAccumulator{}
537 var usage ai.Usage
538 return func(yield func(ai.StreamPart) bool) {
539 if len(warnings) > 0 {
540 if !yield(ai.StreamPart{
541 Type: ai.StreamPartTypeWarnings,
542 Warnings: warnings,
543 }) {
544 return
545 }
546 }
547 for stream.Next() {
548 chunk := stream.Current()
549 acc.AddChunk(chunk)
550 if chunk.Usage.TotalTokens > 0 {
551 // we do this here because the acc does not add prompt details
552 completionTokenDetails := chunk.Usage.CompletionTokensDetails
553 promptTokenDetails := chunk.Usage.PromptTokensDetails
554 usage = ai.Usage{
555 InputTokens: chunk.Usage.PromptTokens,
556 OutputTokens: chunk.Usage.CompletionTokens,
557 TotalTokens: chunk.Usage.TotalTokens,
558 ReasoningTokens: completionTokenDetails.ReasoningTokens,
559 CacheReadTokens: promptTokenDetails.CachedTokens,
560 }
561
562 // Add prediction tokens if available
563 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
564 if completionTokenDetails.AcceptedPredictionTokens > 0 {
565 streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
566 }
567 if completionTokenDetails.RejectedPredictionTokens > 0 {
568 streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
569 }
570 }
571 }
572 if len(chunk.Choices) == 0 {
573 continue
574 }
575 for _, choice := range chunk.Choices {
576 switch {
577 case choice.Delta.Content != "":
578 if !isActiveText {
579 isActiveText = true
580 if !yield(ai.StreamPart{
581 Type: ai.StreamPartTypeTextStart,
582 ID: "0",
583 }) {
584 return
585 }
586 }
587 if !yield(ai.StreamPart{
588 Type: ai.StreamPartTypeTextDelta,
589 ID: "0",
590 Delta: choice.Delta.Content,
591 }) {
592 return
593 }
594 case len(choice.Delta.ToolCalls) > 0:
595 if isActiveText {
596 isActiveText = false
597 if !yield(ai.StreamPart{
598 Type: ai.StreamPartTypeTextEnd,
599 ID: "0",
600 }) {
601 return
602 }
603 }
604
605 for _, toolCallDelta := range choice.Delta.ToolCalls {
606 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
607 if existingToolCall.hasFinished {
608 continue
609 }
610 if toolCallDelta.Function.Arguments != "" {
611 existingToolCall.arguments += toolCallDelta.Function.Arguments
612 }
613 if !yield(ai.StreamPart{
614 Type: ai.StreamPartTypeToolInputDelta,
615 ID: existingToolCall.id,
616 Delta: toolCallDelta.Function.Arguments,
617 }) {
618 return
619 }
620 toolCalls[toolCallDelta.Index] = existingToolCall
621 if existingToolCall.arguments != "" && ai.IsParsableJSON(existingToolCall.arguments) {
622 if !yield(ai.StreamPart{
623 Type: ai.StreamPartTypeToolInputEnd,
624 ID: existingToolCall.id,
625 }) {
626 return
627 }
628
629 if !yield(ai.StreamPart{
630 Type: ai.StreamPartTypeToolCall,
631 ID: existingToolCall.id,
632 ToolCallName: existingToolCall.name,
633 ToolCallInput: existingToolCall.arguments,
634 }) {
635 return
636 }
637 existingToolCall.hasFinished = true
638 toolCalls[toolCallDelta.Index] = existingToolCall
639 }
640 } else {
641 // Does not exist
642 var err error
643 if toolCallDelta.Type != "function" {
644 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
645 }
646 if toolCallDelta.ID == "" {
647 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
648 }
649 if toolCallDelta.Function.Name == "" {
650 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
651 }
652 if err != nil {
653 yield(ai.StreamPart{
654 Type: ai.StreamPartTypeError,
655 Error: o.handleError(stream.Err()),
656 })
657 return
658 }
659
660 if !yield(ai.StreamPart{
661 Type: ai.StreamPartTypeToolInputStart,
662 ID: toolCallDelta.ID,
663 ToolCallName: toolCallDelta.Function.Name,
664 }) {
665 return
666 }
667 toolCalls[toolCallDelta.Index] = toolCall{
668 id: toolCallDelta.ID,
669 name: toolCallDelta.Function.Name,
670 arguments: toolCallDelta.Function.Arguments,
671 }
672
673 exTc := toolCalls[toolCallDelta.Index]
674 if exTc.arguments != "" {
675 if !yield(ai.StreamPart{
676 Type: ai.StreamPartTypeToolInputDelta,
677 ID: exTc.id,
678 Delta: exTc.arguments,
679 }) {
680 return
681 }
682 if ai.IsParsableJSON(toolCalls[toolCallDelta.Index].arguments) {
683 if !yield(ai.StreamPart{
684 Type: ai.StreamPartTypeToolInputEnd,
685 ID: toolCallDelta.ID,
686 }) {
687 return
688 }
689
690 if !yield(ai.StreamPart{
691 Type: ai.StreamPartTypeToolCall,
692 ID: exTc.id,
693 ToolCallName: exTc.name,
694 ToolCallInput: exTc.arguments,
695 }) {
696 return
697 }
698 exTc.hasFinished = true
699 toolCalls[toolCallDelta.Index] = exTc
700 }
701 }
702 continue
703 }
704 }
705 }
706 }
707
708 // Check for annotations in the delta's raw JSON
709 for _, choice := range chunk.Choices {
710 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
711 for _, annotation := range annotations {
712 if annotation.Type == "url_citation" {
713 if !yield(ai.StreamPart{
714 Type: ai.StreamPartTypeSource,
715 ID: uuid.NewString(),
716 SourceType: ai.SourceTypeURL,
717 URL: annotation.URLCitation.URL,
718 Title: annotation.URLCitation.Title,
719 }) {
720 return
721 }
722 }
723 }
724 }
725 }
726 }
727 err := stream.Err()
728 if err == nil || errors.Is(err, io.EOF) {
729 // finished
730 if isActiveText {
731 isActiveText = false
732 if !yield(ai.StreamPart{
733 Type: ai.StreamPartTypeTextEnd,
734 ID: "0",
735 }) {
736 return
737 }
738 }
739
740 // Add logprobs if available
741 if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
742 streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
743 }
744
745 // Handle annotations/citations from accumulated response
746 if len(acc.Choices) > 0 {
747 for _, annotation := range acc.Choices[0].Message.Annotations {
748 if annotation.Type == "url_citation" {
749 if !yield(ai.StreamPart{
750 Type: ai.StreamPartTypeSource,
751 ID: acc.ID,
752 SourceType: ai.SourceTypeURL,
753 URL: annotation.URLCitation.URL,
754 Title: annotation.URLCitation.Title,
755 }) {
756 return
757 }
758 }
759 }
760 }
761
762 finishReason := mapOpenAIFinishReason(acc.Choices[0].FinishReason)
763 yield(ai.StreamPart{
764 Type: ai.StreamPartTypeFinish,
765 Usage: usage,
766 FinishReason: finishReason,
767 ProviderMetadata: streamProviderMetadata,
768 })
769 return
770 } else {
771 yield(ai.StreamPart{
772 Type: ai.StreamPartTypeError,
773 Error: o.handleError(err),
774 })
775 return
776 }
777 }, nil
778}
779
780func mapOpenAIFinishReason(finishReason string) ai.FinishReason {
781 switch finishReason {
782 case "stop":
783 return ai.FinishReasonStop
784 case "length":
785 return ai.FinishReasonLength
786 case "content_filter":
787 return ai.FinishReasonContentFilter
788 case "function_call", "tool_calls":
789 return ai.FinishReasonToolCalls
790 default:
791 return ai.FinishReasonUnknown
792 }
793}
794
// isReasoningModel reports whether modelID names an OpenAI reasoning model
// (o-series or the gpt-5 family). The gpt-5-chat models are conversational,
// not reasoning, so they are excluded — the previous version's
// `|| HasPrefix("gpt-5-chat")` clause was dead code (subsumed by the
// "gpt-5" prefix) where an exclusion was intended.
func isReasoningModel(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-chat") {
		return false
	}
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
798
// isSearchPreviewModel reports whether modelID refers to one of the
// "search-preview" chat models (e.g. gpt-4o-search-preview).
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
802
// supportsFlexProcessing reports whether modelID is eligible for the "flex"
// service tier (o3, o4-mini, and gpt-5 model families).
func supportsFlexProcessing(modelID string) bool {
	for _, family := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, family) {
			return true
		}
	}
	return false
}
806
// supportsPriorityProcessing reports whether modelID is eligible for the
// "priority" service tier: gpt-4, gpt-5 (including gpt-5-mini), o3, and
// o4-mini families. gpt-5-nano is explicitly excluded — the warning emitted
// by prepareParams states it is unsupported, but the bare "gpt-5" prefix
// used previously matched it anyway.
func supportsPriorityProcessing(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-nano") {
		return false
	}
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}
812
813func toOpenAITools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAITools []openai.ChatCompletionToolUnionParam, openAIToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
814 for _, tool := range tools {
815 if tool.GetType() == ai.ToolTypeFunction {
816 ft, ok := tool.(ai.FunctionTool)
817 if !ok {
818 continue
819 }
820 openAITools = append(openAITools, openai.ChatCompletionToolUnionParam{
821 OfFunction: &openai.ChatCompletionFunctionToolParam{
822 Function: shared.FunctionDefinitionParam{
823 Name: ft.Name,
824 Description: param.NewOpt(ft.Description),
825 Parameters: openai.FunctionParameters(ft.InputSchema),
826 Strict: param.NewOpt(false),
827 },
828 Type: "function",
829 },
830 })
831 continue
832 }
833
834 // TODO: handle provider tool calls
835 warnings = append(warnings, ai.CallWarning{
836 Type: ai.CallWarningTypeUnsupportedTool,
837 Tool: tool,
838 Message: "tool is not supported",
839 })
840 }
841 if toolChoice == nil {
842 return
843 }
844
845 switch *toolChoice {
846 case ai.ToolChoiceAuto:
847 openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
848 OfAuto: param.NewOpt("auto"),
849 }
850 case ai.ToolChoiceNone:
851 openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
852 OfAuto: param.NewOpt("none"),
853 }
854 default:
855 openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
856 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
857 Type: "function",
858 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
859 Name: string(*toolChoice),
860 },
861 },
862 }
863 }
864 return
865}
866
867func toOpenAIPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
868 var messages []openai.ChatCompletionMessageParamUnion
869 var warnings []ai.CallWarning
870 for _, msg := range prompt {
871 switch msg.Role {
872 case ai.MessageRoleSystem:
873 var systemPromptParts []string
874 for _, c := range msg.Content {
875 if c.GetType() != ai.ContentTypeText {
876 warnings = append(warnings, ai.CallWarning{
877 Type: ai.CallWarningTypeOther,
878 Message: "system prompt can only have text content",
879 })
880 continue
881 }
882 textPart, ok := ai.AsContentType[ai.TextPart](c)
883 if !ok {
884 warnings = append(warnings, ai.CallWarning{
885 Type: ai.CallWarningTypeOther,
886 Message: "system prompt text part does not have the right type",
887 })
888 continue
889 }
890 text := textPart.Text
891 if strings.TrimSpace(text) != "" {
892 systemPromptParts = append(systemPromptParts, textPart.Text)
893 }
894 }
895 if len(systemPromptParts) == 0 {
896 warnings = append(warnings, ai.CallWarning{
897 Type: ai.CallWarningTypeOther,
898 Message: "system prompt has no text parts",
899 })
900 continue
901 }
902 messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
903 case ai.MessageRoleUser:
904 // simple user message just text content
905 if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
906 textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
907 if !ok {
908 warnings = append(warnings, ai.CallWarning{
909 Type: ai.CallWarningTypeOther,
910 Message: "user message text part does not have the right type",
911 })
912 continue
913 }
914 messages = append(messages, openai.UserMessage(textPart.Text))
915 continue
916 }
917 // text content and attachments
918 // for now we only support image content later we need to check
919 // TODO: add the supported media types to the language model so we
920 // can use that to validate the data here.
921 var content []openai.ChatCompletionContentPartUnionParam
922 for _, c := range msg.Content {
923 switch c.GetType() {
924 case ai.ContentTypeText:
925 textPart, ok := ai.AsContentType[ai.TextPart](c)
926 if !ok {
927 warnings = append(warnings, ai.CallWarning{
928 Type: ai.CallWarningTypeOther,
929 Message: "user message text part does not have the right type",
930 })
931 continue
932 }
933 content = append(content, openai.ChatCompletionContentPartUnionParam{
934 OfText: &openai.ChatCompletionContentPartTextParam{
935 Text: textPart.Text,
936 },
937 })
938 case ai.ContentTypeFile:
939 filePart, ok := ai.AsContentType[ai.FilePart](c)
940 if !ok {
941 warnings = append(warnings, ai.CallWarning{
942 Type: ai.CallWarningTypeOther,
943 Message: "user message file part does not have the right type",
944 })
945 continue
946 }
947
948 switch {
949 case strings.HasPrefix(filePart.MediaType, "image/"):
950 // Handle image files
951 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
952 data := "data:" + filePart.MediaType + ";base64," + base64Encoded
953 imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
954
955 // Check for provider-specific options like image detail
956 if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
957 if detail, ok := providerOptions["imageDetail"].(string); ok {
958 imageURL.Detail = detail
959 }
960 }
961
962 imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
963 content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
964
965 case filePart.MediaType == "audio/wav":
966 // Handle WAV audio files
967 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
968 audioBlock := openai.ChatCompletionContentPartInputAudioParam{
969 InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
970 Data: base64Encoded,
971 Format: "wav",
972 },
973 }
974 content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
975
976 case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
977 // Handle MP3 audio files
978 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
979 audioBlock := openai.ChatCompletionContentPartInputAudioParam{
980 InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
981 Data: base64Encoded,
982 Format: "mp3",
983 },
984 }
985 content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
986
987 case filePart.MediaType == "application/pdf":
988 // Handle PDF files
989 dataStr := string(filePart.Data)
990
991 // Check if data looks like a file ID (starts with "file-")
992 if strings.HasPrefix(dataStr, "file-") {
993 fileBlock := openai.ChatCompletionContentPartFileParam{
994 File: openai.ChatCompletionContentPartFileFileParam{
995 FileID: param.NewOpt(dataStr),
996 },
997 }
998 content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
999 } else {
1000 // Handle as base64 data
1001 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
1002 data := "data:application/pdf;base64," + base64Encoded
1003
1004 filename := filePart.Filename
1005 if filename == "" {
1006 // Generate default filename based on content index
1007 filename = fmt.Sprintf("part-%d.pdf", len(content))
1008 }
1009
1010 fileBlock := openai.ChatCompletionContentPartFileParam{
1011 File: openai.ChatCompletionContentPartFileFileParam{
1012 Filename: param.NewOpt(filename),
1013 FileData: param.NewOpt(data),
1014 },
1015 }
1016 content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1017 }
1018
1019 default:
1020 warnings = append(warnings, ai.CallWarning{
1021 Type: ai.CallWarningTypeOther,
1022 Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
1023 })
1024 }
1025 }
1026 }
1027 messages = append(messages, openai.UserMessage(content))
1028 case ai.MessageRoleAssistant:
1029 // simple assistant message just text content
1030 if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
1031 textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
1032 if !ok {
1033 warnings = append(warnings, ai.CallWarning{
1034 Type: ai.CallWarningTypeOther,
1035 Message: "assistant message text part does not have the right type",
1036 })
1037 continue
1038 }
1039 messages = append(messages, openai.AssistantMessage(textPart.Text))
1040 continue
1041 }
1042 assistantMsg := openai.ChatCompletionAssistantMessageParam{
1043 Role: "assistant",
1044 }
1045 for _, c := range msg.Content {
1046 switch c.GetType() {
1047 case ai.ContentTypeText:
1048 textPart, ok := ai.AsContentType[ai.TextPart](c)
1049 if !ok {
1050 warnings = append(warnings, ai.CallWarning{
1051 Type: ai.CallWarningTypeOther,
1052 Message: "assistant message text part does not have the right type",
1053 })
1054 continue
1055 }
1056 assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
1057 OfString: param.NewOpt(textPart.Text),
1058 }
1059 case ai.ContentTypeToolCall:
1060 toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
1061 if !ok {
1062 warnings = append(warnings, ai.CallWarning{
1063 Type: ai.CallWarningTypeOther,
1064 Message: "assistant message tool part does not have the right type",
1065 })
1066 continue
1067 }
1068 assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
1069 openai.ChatCompletionMessageToolCallUnionParam{
1070 OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
1071 ID: toolCallPart.ToolCallID,
1072 Type: "function",
1073 Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
1074 Name: toolCallPart.ToolName,
1075 Arguments: toolCallPart.Input,
1076 },
1077 },
1078 })
1079 }
1080 }
1081 messages = append(messages, openai.ChatCompletionMessageParamUnion{
1082 OfAssistant: &assistantMsg,
1083 })
1084 case ai.MessageRoleTool:
1085 for _, c := range msg.Content {
1086 if c.GetType() != ai.ContentTypeToolResult {
1087 warnings = append(warnings, ai.CallWarning{
1088 Type: ai.CallWarningTypeOther,
1089 Message: "tool message can only have tool result content",
1090 })
1091 continue
1092 }
1093
1094 toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
1095 if !ok {
1096 warnings = append(warnings, ai.CallWarning{
1097 Type: ai.CallWarningTypeOther,
1098 Message: "tool message result part does not have the right type",
1099 })
1100 continue
1101 }
1102
1103 switch toolResultPart.Output.GetType() {
1104 case ai.ToolResultContentTypeText:
1105 output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
1106 if !ok {
1107 warnings = append(warnings, ai.CallWarning{
1108 Type: ai.CallWarningTypeOther,
1109 Message: "tool result output does not have the right type",
1110 })
1111 continue
1112 }
1113 messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
1114 case ai.ToolResultContentTypeError:
1115 // TODO: check if better handling is needed
1116 output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
1117 if !ok {
1118 warnings = append(warnings, ai.CallWarning{
1119 Type: ai.CallWarningTypeOther,
1120 Message: "tool result output does not have the right type",
1121 })
1122 continue
1123 }
1124 messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
1125 }
1126 }
1127 }
1128 }
1129 return messages, warnings
1130}
1131
1132// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta
1133func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
1134 var annotations []openai.ChatCompletionMessageAnnotation
1135
1136 // Parse the raw JSON to extract annotations
1137 var deltaData map[string]any
1138 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
1139 return annotations
1140 }
1141
1142 // Check if annotations exist in the delta
1143 if annotationsData, ok := deltaData["annotations"].([]any); ok {
1144 for _, annotationData := range annotationsData {
1145 if annotationMap, ok := annotationData.(map[string]any); ok {
1146 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
1147 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
1148 annotation := openai.ChatCompletionMessageAnnotation{
1149 Type: "url_citation",
1150 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
1151 URL: urlCitationData["url"].(string),
1152 Title: urlCitationData["title"].(string),
1153 },
1154 }
1155 annotations = append(annotations, annotation)
1156 }
1157 }
1158 }
1159 }
1160 }
1161
1162 return annotations
1163}