1package providers
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "maps"
11 "strings"
12
13 "github.com/charmbracelet/crush/internal/ai"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/env"
16 "github.com/google/uuid"
17 "github.com/openai/openai-go/v2"
18 "github.com/openai/openai-go/v2/option"
19 "github.com/openai/openai-go/v2/packages/param"
20 "github.com/openai/openai-go/v2/shared"
21)
22
// ReasoningEffort selects how much internal reasoning an OpenAI
// reasoning-capable model spends before producing an answer.
type ReasoningEffort string

// Reasoning effort levels accepted by prepareParams and mapped onto the
// corresponding shared.ReasoningEffort* request values.
const (
	ReasoningEffortMinimal ReasoningEffort = "minimal"
	ReasoningEffortLow     ReasoningEffort = "low"
	ReasoningEffortMedium  ReasoningEffort = "medium"
	ReasoningEffortHigh    ReasoningEffort = "high"
)
31
32type OpenAIProviderOptions struct {
33 LogitBias map[string]int64 `json:"logit_bias"`
34 LogProbs *bool `json:"log_probes"`
35 TopLogProbs *int64 `json:"top_log_probs"`
36 ParallelToolCalls *bool `json:"parallel_tool_calls"`
37 User *string `json:"user"`
38 ReasoningEffort *ReasoningEffort `json:"reasoning_effort"`
39 MaxCompletionTokens *int64 `json:"max_completion_tokens"`
40 TextVerbosity *string `json:"text_verbosity"`
41 Prediction map[string]any `json:"prediction"`
42 Store *bool `json:"store"`
43 Metadata map[string]any `json:"metadata"`
44 PromptCacheKey *string `json:"prompt_cache_key"`
45 SafetyIdentifier *string `json:"safety_identifier"`
46 ServiceTier *string `json:"service_tier"`
47 StructuredOutputs *bool `json:"structured_outputs"`
48}
49
// openAIProvider is the concrete ai.Provider backed by the OpenAI
// chat-completions API.
type openAIProvider struct {
	options openAIProviderOptions
}

// openAIProviderOptions holds provider construction settings, populated
// by the WithOpenAI* functional options and resolved in NewOpenAIProvider.
type openAIProviderOptions struct {
	baseURL      string            // API base URL; defaults to https://api.openai.com/v1
	apiKey       string            // may contain a variable reference, resolved at construction
	organization string            // sent as the OpenAI-Organization header when non-empty
	project      string            // sent as the OpenAI-Project header when non-empty
	name         string            // provider name used in the model's Provider() string; defaults to "openai"
	headers      map[string]string // extra request headers; values are variable-resolved
	client       option.HTTPClient // optional custom HTTP client
	resolver     config.VariableResolver // resolves env/shell references in the string options above
}
64
// OpenAIOption mutates provider construction settings; pass to NewOpenAIProvider.
type OpenAIOption = func(*openAIProviderOptions)
66
67func NewOpenAIProvider(opts ...OpenAIOption) ai.Provider {
68 options := openAIProviderOptions{
69 headers: map[string]string{},
70 }
71 for _, o := range opts {
72 o(&options)
73 }
74
75 if options.resolver == nil {
76 // use the default resolver
77 options.resolver = config.NewShellVariableResolver(env.New())
78 }
79 options.apiKey, _ = options.resolver.ResolveValue(options.apiKey)
80 options.baseURL, _ = options.resolver.ResolveValue(options.baseURL)
81 if options.baseURL == "" {
82 options.baseURL = "https://api.openai.com/v1"
83 }
84
85 options.name, _ = options.resolver.ResolveValue(options.name)
86 if options.name == "" {
87 options.name = "openai"
88 }
89
90 for k, v := range options.headers {
91 options.headers[k], _ = options.resolver.ResolveValue(v)
92 }
93
94 options.organization, _ = options.resolver.ResolveValue(options.organization)
95 if options.organization != "" {
96 options.headers["OpenAI-Organization"] = options.organization
97 }
98
99 options.project, _ = options.resolver.ResolveValue(options.project)
100 if options.project != "" {
101 options.headers["OpenAI-Project"] = options.project
102 }
103
104 return &openAIProvider{
105 options: options,
106 }
107}
108
109func WithOpenAIBaseURL(baseURL string) OpenAIOption {
110 return func(o *openAIProviderOptions) {
111 o.baseURL = baseURL
112 }
113}
114
115func WithOpenAIApiKey(apiKey string) OpenAIOption {
116 return func(o *openAIProviderOptions) {
117 o.apiKey = apiKey
118 }
119}
120
121func WithOpenAIOrganization(organization string) OpenAIOption {
122 return func(o *openAIProviderOptions) {
123 o.organization = organization
124 }
125}
126
127func WithOpenAIProject(project string) OpenAIOption {
128 return func(o *openAIProviderOptions) {
129 o.project = project
130 }
131}
132
133func WithOpenAIName(name string) OpenAIOption {
134 return func(o *openAIProviderOptions) {
135 o.name = name
136 }
137}
138
139func WithOpenAIHeaders(headers map[string]string) OpenAIOption {
140 return func(o *openAIProviderOptions) {
141 maps.Copy(o.headers, headers)
142 }
143}
144
145func WithOpenAIHttpClient(client option.HTTPClient) OpenAIOption {
146 return func(o *openAIProviderOptions) {
147 o.client = client
148 }
149}
150
151func WithOpenAIVariableResolver(resolver config.VariableResolver) OpenAIOption {
152 return func(o *openAIProviderOptions) {
153 o.resolver = resolver
154 }
155}
156
157// LanguageModel implements ai.Provider.
158func (o *openAIProvider) LanguageModel(modelID string) ai.LanguageModel {
159 openaiClientOptions := []option.RequestOption{}
160 if o.options.apiKey != "" {
161 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
162 }
163 if o.options.baseURL != "" {
164 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
165 }
166
167 for key, value := range o.options.headers {
168 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
169 }
170
171 if o.options.client != nil {
172 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
173 }
174
175 return openAILanguageModel{
176 modelID: modelID,
177 provider: fmt.Sprintf("%s.chat", o.options.name),
178 providerOptions: o.options,
179 client: openai.NewClient(openaiClientOptions...),
180 }
181}
182
// openAILanguageModel is an ai.LanguageModel bound to a single OpenAI
// chat-completions model.
type openAILanguageModel struct {
	provider        string // "<providerName>.chat", set in LanguageModel
	modelID         string
	client          openai.Client
	providerOptions openAIProviderOptions
}
189
// Model implements ai.LanguageModel. It returns the model ID this
// instance was created for.
func (o openAILanguageModel) Model() string {
	return o.modelID
}
194
// Provider implements ai.LanguageModel. It returns the provider label
// ("<name>.chat") assigned at construction.
func (o openAILanguageModel) Provider() string {
	return o.provider
}
199
200func (o openAILanguageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
201 params := &openai.ChatCompletionNewParams{}
202 messages, warnings := toOpenAIPrompt(call.Prompt)
203 providerOptions := &OpenAIProviderOptions{}
204 if v, ok := call.ProviderOptions["openai"]; ok {
205 err := ai.ParseOptions(v, providerOptions)
206 if err != nil {
207 return nil, nil, err
208 }
209 }
210 if call.TopK != nil {
211 warnings = append(warnings, ai.CallWarning{
212 Type: ai.CallWarningTypeUnsupportedSetting,
213 Setting: "top_k",
214 })
215 }
216 params.Messages = messages
217 params.Model = o.modelID
218 if providerOptions.LogitBias != nil {
219 params.LogitBias = providerOptions.LogitBias
220 }
221 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
222 providerOptions.LogProbs = nil
223 }
224 if providerOptions.LogProbs != nil {
225 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
226 }
227 if providerOptions.TopLogProbs != nil {
228 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
229 }
230 if providerOptions.User != nil {
231 params.User = param.NewOpt(*providerOptions.User)
232 }
233 if providerOptions.ParallelToolCalls != nil {
234 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
235 }
236
237 if call.MaxOutputTokens != nil {
238 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
239 }
240 if call.Temperature != nil {
241 params.Temperature = param.NewOpt(*call.Temperature)
242 }
243 if call.TopP != nil {
244 params.TopP = param.NewOpt(*call.TopP)
245 }
246 if call.FrequencyPenalty != nil {
247 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
248 }
249 if call.PresencePenalty != nil {
250 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
251 }
252
253 if providerOptions.MaxCompletionTokens != nil {
254 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
255 }
256
257 if providerOptions.TextVerbosity != nil {
258 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
259 }
260 if providerOptions.Prediction != nil {
261 // Convert map[string]any to ChatCompletionPredictionContentParam
262 if content, ok := providerOptions.Prediction["content"]; ok {
263 if contentStr, ok := content.(string); ok {
264 params.Prediction = openai.ChatCompletionPredictionContentParam{
265 Content: openai.ChatCompletionPredictionContentContentUnionParam{
266 OfString: param.NewOpt(contentStr),
267 },
268 }
269 }
270 }
271 }
272 if providerOptions.Store != nil {
273 params.Store = param.NewOpt(*providerOptions.Store)
274 }
275 if providerOptions.Metadata != nil {
276 // Convert map[string]any to map[string]string
277 metadata := make(map[string]string)
278 for k, v := range providerOptions.Metadata {
279 if str, ok := v.(string); ok {
280 metadata[k] = str
281 }
282 }
283 params.Metadata = metadata
284 }
285 if providerOptions.PromptCacheKey != nil {
286 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
287 }
288 if providerOptions.SafetyIdentifier != nil {
289 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
290 }
291 if providerOptions.ServiceTier != nil {
292 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
293 }
294
295 if providerOptions.ReasoningEffort != nil {
296 switch *providerOptions.ReasoningEffort {
297 case ReasoningEffortMinimal:
298 params.ReasoningEffort = shared.ReasoningEffortMinimal
299 case ReasoningEffortLow:
300 params.ReasoningEffort = shared.ReasoningEffortLow
301 case ReasoningEffortMedium:
302 params.ReasoningEffort = shared.ReasoningEffortMedium
303 case ReasoningEffortHigh:
304 params.ReasoningEffort = shared.ReasoningEffortHigh
305 default:
306 return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
307 }
308 }
309
310 if isReasoningModel(o.modelID) {
311 // remove unsupported settings for reasoning models
312 // see https://platform.openai.com/docs/guides/reasoning#limitations
313 if call.Temperature != nil {
314 params.Temperature = param.Opt[float64]{}
315 warnings = append(warnings, ai.CallWarning{
316 Type: ai.CallWarningTypeUnsupportedSetting,
317 Setting: "temperature",
318 Details: "temperature is not supported for reasoning models",
319 })
320 }
321 if call.TopP != nil {
322 params.TopP = param.Opt[float64]{}
323 warnings = append(warnings, ai.CallWarning{
324 Type: ai.CallWarningTypeUnsupportedSetting,
325 Setting: "top_p",
326 Details: "topP is not supported for reasoning models",
327 })
328 }
329 if call.FrequencyPenalty != nil {
330 params.FrequencyPenalty = param.Opt[float64]{}
331 warnings = append(warnings, ai.CallWarning{
332 Type: ai.CallWarningTypeUnsupportedSetting,
333 Setting: "frequency_penalty",
334 Details: "frequencyPenalty is not supported for reasoning models",
335 })
336 }
337 if call.PresencePenalty != nil {
338 params.PresencePenalty = param.Opt[float64]{}
339 warnings = append(warnings, ai.CallWarning{
340 Type: ai.CallWarningTypeUnsupportedSetting,
341 Setting: "presence_penalty",
342 Details: "presencePenalty is not supported for reasoning models",
343 })
344 }
345 if providerOptions.LogitBias != nil {
346 params.LogitBias = nil
347 warnings = append(warnings, ai.CallWarning{
348 Type: ai.CallWarningTypeOther,
349 Message: "logitBias is not supported for reasoning models",
350 })
351 }
352 if providerOptions.LogProbs != nil {
353 params.Logprobs = param.Opt[bool]{}
354 warnings = append(warnings, ai.CallWarning{
355 Type: ai.CallWarningTypeOther,
356 Message: "logprobs is not supported for reasoning models",
357 })
358 }
359 if providerOptions.TopLogProbs != nil {
360 params.TopLogprobs = param.Opt[int64]{}
361 warnings = append(warnings, ai.CallWarning{
362 Type: ai.CallWarningTypeOther,
363 Message: "topLogprobs is not supported for reasoning models",
364 })
365 }
366
367 // reasoning models use max_completion_tokens instead of max_tokens
368 if call.MaxOutputTokens != nil {
369 if providerOptions.MaxCompletionTokens == nil {
370 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
371 }
372 params.MaxTokens = param.Opt[int64]{}
373 }
374 }
375
376 // Handle search preview models
377 if isSearchPreviewModel(o.modelID) {
378 if call.Temperature != nil {
379 params.Temperature = param.Opt[float64]{}
380 warnings = append(warnings, ai.CallWarning{
381 Type: ai.CallWarningTypeUnsupportedSetting,
382 Setting: "temperature",
383 Details: "temperature is not supported for the search preview models and has been removed.",
384 })
385 }
386 }
387
388 // Handle service tier validation
389 if providerOptions.ServiceTier != nil {
390 serviceTier := *providerOptions.ServiceTier
391 if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
392 params.ServiceTier = ""
393 warnings = append(warnings, ai.CallWarning{
394 Type: ai.CallWarningTypeUnsupportedSetting,
395 Setting: "serviceTier",
396 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
397 })
398 } else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
399 params.ServiceTier = ""
400 warnings = append(warnings, ai.CallWarning{
401 Type: ai.CallWarningTypeUnsupportedSetting,
402 Setting: "serviceTier",
403 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
404 })
405 }
406 }
407
408 if len(call.Tools) > 0 {
409 tools, toolChoice, toolWarnings := toOpenAITools(call.Tools, call.ToolChoice)
410 params.Tools = tools
411 if toolChoice != nil {
412 params.ToolChoice = *toolChoice
413 }
414 warnings = append(warnings, toolWarnings...)
415 }
416 return params, warnings, nil
417}
418
// Generate implements ai.LanguageModel. It issues a single (non-streaming)
// chat completion and maps the first choice into generic ai content:
// text, tool calls, and url_citation annotations as sources. Returns an
// error if the API call fails or no choices are returned.
func (o openAILanguageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, err
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	// Only the first choice is consumed; n>1 is not requested anywhere here.
	choice := response.Choices[0]
	var content []ai.Content
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		// Some backends omit tool-call IDs; synthesize one so downstream
		// correlation still works.
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAIFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}
503
// toolCall tracks one streaming tool call while its JSON arguments are
// accumulated across deltas (keyed by the choice delta index in Stream).
type toolCall struct {
	id          string
	name        string
	arguments   string // concatenated argument fragments; emitted once parsable as JSON
	hasFinished bool   // set after the final tool-call part has been yielded
}
510
511// Stream implements ai.LanguageModel.
512func (o openAILanguageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
513 params, warnings, err := o.prepareParams(call)
514 if err != nil {
515 return nil, err
516 }
517
518 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
519 IncludeUsage: openai.Bool(true),
520 }
521
522 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
523 isActiveText := false
524 toolCalls := make(map[int64]toolCall)
525
526 // Build provider metadata for streaming
527 streamProviderMetadata := ai.ProviderOptions{
528 "openai": make(map[string]any),
529 }
530
531 acc := openai.ChatCompletionAccumulator{}
532 var usage ai.Usage
533 return func(yield func(ai.StreamPart) bool) {
534 if len(warnings) > 0 {
535 if !yield(ai.StreamPart{
536 Type: ai.StreamPartTypeWarnings,
537 Warnings: warnings,
538 }) {
539 return
540 }
541 }
542 for stream.Next() {
543 chunk := stream.Current()
544 acc.AddChunk(chunk)
545 if chunk.Usage.TotalTokens > 0 {
546 // we do this here because the acc does not add prompt details
547 completionTokenDetails := chunk.Usage.CompletionTokensDetails
548 promptTokenDetails := chunk.Usage.PromptTokensDetails
549 usage = ai.Usage{
550 InputTokens: chunk.Usage.PromptTokens,
551 OutputTokens: chunk.Usage.CompletionTokens,
552 TotalTokens: chunk.Usage.TotalTokens,
553 ReasoningTokens: completionTokenDetails.ReasoningTokens,
554 CacheReadTokens: promptTokenDetails.CachedTokens,
555 }
556
557 // Add prediction tokens if available
558 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
559 if completionTokenDetails.AcceptedPredictionTokens > 0 {
560 streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
561 }
562 if completionTokenDetails.RejectedPredictionTokens > 0 {
563 streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
564 }
565 }
566 }
567 if len(chunk.Choices) == 0 {
568 continue
569 }
570 for _, choice := range chunk.Choices {
571 switch {
572 case choice.Delta.Content != "":
573 if !isActiveText {
574 isActiveText = true
575 if !yield(ai.StreamPart{
576 Type: ai.StreamPartTypeTextStart,
577 ID: "0",
578 }) {
579 return
580 }
581 }
582 if !yield(ai.StreamPart{
583 Type: ai.StreamPartTypeTextDelta,
584 ID: "0",
585 Delta: choice.Delta.Content,
586 }) {
587 return
588 }
589 case len(choice.Delta.ToolCalls) > 0:
590 if isActiveText {
591 isActiveText = false
592 if !yield(ai.StreamPart{
593 Type: ai.StreamPartTypeTextEnd,
594 ID: "0",
595 }) {
596 return
597 }
598 }
599
600 for _, toolCallDelta := range choice.Delta.ToolCalls {
601 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
602 if existingToolCall.hasFinished {
603 continue
604 }
605 if toolCallDelta.Function.Arguments != "" {
606 existingToolCall.arguments += toolCallDelta.Function.Arguments
607 }
608 if !yield(ai.StreamPart{
609 Type: ai.StreamPartTypeToolInputDelta,
610 ID: existingToolCall.id,
611 Delta: toolCallDelta.Function.Arguments,
612 }) {
613 return
614 }
615 toolCalls[toolCallDelta.Index] = existingToolCall
616 if existingToolCall.arguments != "" && ai.IsParsableJSON(existingToolCall.arguments) {
617 if !yield(ai.StreamPart{
618 Type: ai.StreamPartTypeToolInputEnd,
619 ID: existingToolCall.id,
620 }) {
621 return
622 }
623
624 if !yield(ai.StreamPart{
625 Type: ai.StreamPartTypeToolCall,
626 ID: existingToolCall.id,
627 ToolCallName: existingToolCall.name,
628 ToolCallInput: existingToolCall.arguments,
629 }) {
630 return
631 }
632 existingToolCall.hasFinished = true
633 toolCalls[toolCallDelta.Index] = existingToolCall
634 }
635
636 } else {
637 // Does not exist
638 var err error
639 if toolCallDelta.Type != "function" {
640 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
641 }
642 if toolCallDelta.ID == "" {
643 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
644 }
645 if toolCallDelta.Function.Name == "" {
646 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
647 }
648 if err != nil {
649 yield(ai.StreamPart{
650 Type: ai.StreamPartTypeError,
651 Error: stream.Err(),
652 })
653 return
654 }
655
656 if !yield(ai.StreamPart{
657 Type: ai.StreamPartTypeToolInputStart,
658 ID: toolCallDelta.ID,
659 ToolCallName: toolCallDelta.Function.Name,
660 }) {
661 return
662 }
663 toolCalls[toolCallDelta.Index] = toolCall{
664 id: toolCallDelta.ID,
665 name: toolCallDelta.Function.Name,
666 arguments: toolCallDelta.Function.Arguments,
667 }
668
669 exTc := toolCalls[toolCallDelta.Index]
670 if exTc.arguments != "" {
671 if !yield(ai.StreamPart{
672 Type: ai.StreamPartTypeToolInputDelta,
673 ID: exTc.id,
674 Delta: exTc.arguments,
675 }) {
676 return
677 }
678 if ai.IsParsableJSON(toolCalls[toolCallDelta.Index].arguments) {
679 if !yield(ai.StreamPart{
680 Type: ai.StreamPartTypeToolInputEnd,
681 ID: toolCallDelta.ID,
682 }) {
683 return
684 }
685
686 if !yield(ai.StreamPart{
687 Type: ai.StreamPartTypeToolCall,
688 ID: exTc.id,
689 ToolCallName: exTc.name,
690 ToolCallInput: exTc.arguments,
691 }) {
692 return
693 }
694 exTc.hasFinished = true
695 toolCalls[toolCallDelta.Index] = exTc
696 }
697 }
698 continue
699 }
700 }
701 }
702 }
703
704 // Check for annotations in the delta's raw JSON
705 for _, choice := range chunk.Choices {
706 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
707 for _, annotation := range annotations {
708 if annotation.Type == "url_citation" {
709 if !yield(ai.StreamPart{
710 Type: ai.StreamPartTypeSource,
711 ID: uuid.NewString(),
712 SourceType: ai.SourceTypeURL,
713 URL: annotation.URLCitation.URL,
714 Title: annotation.URLCitation.Title,
715 }) {
716 return
717 }
718 }
719 }
720 }
721 }
722
723 }
724 err := stream.Err()
725 if err == nil || errors.Is(err, io.EOF) {
726 // finished
727 if isActiveText {
728 isActiveText = false
729 if !yield(ai.StreamPart{
730 Type: ai.StreamPartTypeTextEnd,
731 ID: "0",
732 }) {
733 return
734 }
735 }
736
737 // Add logprobs if available
738 if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
739 streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
740 }
741
742 // Handle annotations/citations from accumulated response
743 if len(acc.Choices) > 0 {
744 for _, annotation := range acc.Choices[0].Message.Annotations {
745 if annotation.Type == "url_citation" {
746 if !yield(ai.StreamPart{
747 Type: ai.StreamPartTypeSource,
748 ID: uuid.NewString(),
749 SourceType: ai.SourceTypeURL,
750 URL: annotation.URLCitation.URL,
751 Title: annotation.URLCitation.Title,
752 }) {
753 return
754 }
755 }
756 }
757 }
758
759 finishReason := mapOpenAIFinishReason(acc.Choices[0].FinishReason)
760 yield(ai.StreamPart{
761 Type: ai.StreamPartTypeFinish,
762 Usage: usage,
763 FinishReason: finishReason,
764 ProviderMetadata: streamProviderMetadata,
765 })
766 return
767
768 } else {
769 yield(ai.StreamPart{
770 Type: ai.StreamPartTypeError,
771 Error: stream.Err(),
772 })
773 return
774 }
775 }, nil
776}
777
// mapOpenAIFinishReason converts OpenAI's finish_reason string into the
// generic ai.FinishReason. Unrecognized values map to FinishReasonUnknown.
func mapOpenAIFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}
792
// isReasoningModel reports whether the model uses OpenAI's reasoning API
// surface (o-series and the gpt-5 family). gpt-5-chat models are
// standard chat models and are excluded — the original
// `HasPrefix("gpt-5-chat")` clause was unreachable (already covered by
// the "gpt-5" prefix) and would have wrongly stripped sampling settings
// from the non-reasoning chat variant.
func isReasoningModel(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-chat") {
		return false
	}
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
796
// isSearchPreviewModel reports whether the model is one of OpenAI's
// "search preview" chat models, which reject some sampling parameters.
func isSearchPreviewModel(modelID string) bool {
	const marker = "search-preview"
	return strings.Contains(modelID, marker)
}
800
// supportsFlexProcessing reports whether the model may use the "flex"
// service tier (o3, o4-mini, and the gpt-5 family).
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
804
// supportsPriorityProcessing reports whether the model may use the
// "priority" service tier (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini).
// gpt-5-nano is explicitly excluded — the warning emitted in
// prepareParams documents this, but the previous prefix check
// ("gpt-5") incorrectly matched it. The redundant gpt-5-mini clause
// (already covered by "gpt-5") is dropped.
func supportsPriorityProcessing(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-nano") {
		return false
	}
	return strings.HasPrefix(modelID, "gpt-4") || strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini")
}
810
811func toOpenAITools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAITools []openai.ChatCompletionToolUnionParam, openAIToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
812 for _, tool := range tools {
813 if tool.GetType() == ai.ToolTypeFunction {
814 ft, ok := tool.(ai.FunctionTool)
815 if !ok {
816 continue
817 }
818 openAITools = append(openAITools, openai.ChatCompletionToolUnionParam{
819 OfFunction: &openai.ChatCompletionFunctionToolParam{
820 Function: shared.FunctionDefinitionParam{
821 Name: ft.Name,
822 Description: param.NewOpt(ft.Description),
823 Parameters: openai.FunctionParameters(ft.InputSchema),
824 Strict: param.NewOpt(false),
825 },
826 Type: "function",
827 },
828 })
829 continue
830 }
831
832 // TODO: handle provider tool calls
833 warnings = append(warnings, ai.CallWarning{
834 Type: ai.CallWarningTypeUnsupportedTool,
835 Tool: tool,
836 Message: "tool is not supported",
837 })
838 }
839 if toolChoice == nil {
840 return
841 }
842
843 switch *toolChoice {
844 case ai.ToolChoiceAuto:
845 openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
846 OfAuto: param.NewOpt("auto"),
847 }
848 case ai.ToolChoiceNone:
849 openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
850 OfAuto: param.NewOpt("none"),
851 }
852 default:
853 openAIToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
854 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
855 Type: "function",
856 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
857 Name: string(*toolChoice),
858 },
859 },
860 }
861 }
862 return
863}
864
865func toOpenAIPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
866 var messages []openai.ChatCompletionMessageParamUnion
867 var warnings []ai.CallWarning
868 for _, msg := range prompt {
869 switch msg.Role {
870 case ai.MessageRoleSystem:
871 var systemPromptParts []string
872 for _, c := range msg.Content {
873 if c.GetType() != ai.ContentTypeText {
874 warnings = append(warnings, ai.CallWarning{
875 Type: ai.CallWarningTypeOther,
876 Message: "system prompt can only have text content",
877 })
878 continue
879 }
880 textPart, ok := ai.AsContentType[ai.TextPart](c)
881 if !ok {
882 warnings = append(warnings, ai.CallWarning{
883 Type: ai.CallWarningTypeOther,
884 Message: "system prompt text part does not have the right type",
885 })
886 continue
887 }
888 text := textPart.Text
889 if strings.TrimSpace(text) != "" {
890 systemPromptParts = append(systemPromptParts, textPart.Text)
891 }
892 }
893 if len(systemPromptParts) == 0 {
894 warnings = append(warnings, ai.CallWarning{
895 Type: ai.CallWarningTypeOther,
896 Message: "system prompt has no text parts",
897 })
898 continue
899 }
900 messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
901 case ai.MessageRoleUser:
902 // simple user message just text content
903 if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
904 textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
905 if !ok {
906 warnings = append(warnings, ai.CallWarning{
907 Type: ai.CallWarningTypeOther,
908 Message: "user message text part does not have the right type",
909 })
910 continue
911 }
912 messages = append(messages, openai.UserMessage(textPart.Text))
913 continue
914 }
915 // text content and attachments
916 // for now we only support image content later we need to check
917 // TODO: add the supported media types to the language model so we
918 // can use that to validate the data here.
919 var content []openai.ChatCompletionContentPartUnionParam
920 for _, c := range msg.Content {
921 switch c.GetType() {
922 case ai.ContentTypeText:
923 textPart, ok := ai.AsContentType[ai.TextPart](c)
924 if !ok {
925 warnings = append(warnings, ai.CallWarning{
926 Type: ai.CallWarningTypeOther,
927 Message: "user message text part does not have the right type",
928 })
929 continue
930 }
931 content = append(content, openai.ChatCompletionContentPartUnionParam{
932 OfText: &openai.ChatCompletionContentPartTextParam{
933 Text: textPart.Text,
934 },
935 })
936 case ai.ContentTypeFile:
937 filePart, ok := ai.AsContentType[ai.FilePart](c)
938 if !ok {
939 warnings = append(warnings, ai.CallWarning{
940 Type: ai.CallWarningTypeOther,
941 Message: "user message file part does not have the right type",
942 })
943 continue
944 }
945
946 switch {
947 case strings.HasPrefix(filePart.MediaType, "image/"):
948 // Handle image files
949 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
950 data := "data:" + filePart.MediaType + ";base64," + base64Encoded
951 imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
952
953 // Check for provider-specific options like image detail
954 if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
955 if detail, ok := providerOptions["imageDetail"].(string); ok {
956 imageURL.Detail = detail
957 }
958 }
959
960 imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
961 content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
962
963 case filePart.MediaType == "audio/wav":
964 // Handle WAV audio files
965 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
966 audioBlock := openai.ChatCompletionContentPartInputAudioParam{
967 InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
968 Data: base64Encoded,
969 Format: "wav",
970 },
971 }
972 content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
973
974 case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
975 // Handle MP3 audio files
976 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
977 audioBlock := openai.ChatCompletionContentPartInputAudioParam{
978 InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
979 Data: base64Encoded,
980 Format: "mp3",
981 },
982 }
983 content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
984
985 case filePart.MediaType == "application/pdf":
986 // Handle PDF files
987 dataStr := string(filePart.Data)
988
989 // Check if data looks like a file ID (starts with "file-")
990 if strings.HasPrefix(dataStr, "file-") {
991 fileBlock := openai.ChatCompletionContentPartFileParam{
992 File: openai.ChatCompletionContentPartFileFileParam{
993 FileID: param.NewOpt(dataStr),
994 },
995 }
996 content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
997 } else {
998 // Handle as base64 data
999 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
1000 data := "data:application/pdf;base64," + base64Encoded
1001
1002 filename := filePart.Filename
1003 if filename == "" {
1004 // Generate default filename based on content index
1005 filename = fmt.Sprintf("part-%d.pdf", len(content))
1006 }
1007
1008 fileBlock := openai.ChatCompletionContentPartFileParam{
1009 File: openai.ChatCompletionContentPartFileFileParam{
1010 Filename: param.NewOpt(filename),
1011 FileData: param.NewOpt(data),
1012 },
1013 }
1014 content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1015 }
1016
1017 default:
1018 warnings = append(warnings, ai.CallWarning{
1019 Type: ai.CallWarningTypeOther,
1020 Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
1021 })
1022 }
1023 }
1024 }
1025 messages = append(messages, openai.UserMessage(content))
1026 case ai.MessageRoleAssistant:
1027 // simple assistant message just text content
1028 if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
1029 textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
1030 if !ok {
1031 warnings = append(warnings, ai.CallWarning{
1032 Type: ai.CallWarningTypeOther,
1033 Message: "assistant message text part does not have the right type",
1034 })
1035 continue
1036 }
1037 messages = append(messages, openai.AssistantMessage(textPart.Text))
1038 continue
1039 }
1040 assistantMsg := openai.ChatCompletionAssistantMessageParam{
1041 Role: "assistant",
1042 }
1043 for _, c := range msg.Content {
1044 switch c.GetType() {
1045 case ai.ContentTypeText:
1046 textPart, ok := ai.AsContentType[ai.TextPart](c)
1047 if !ok {
1048 warnings = append(warnings, ai.CallWarning{
1049 Type: ai.CallWarningTypeOther,
1050 Message: "assistant message text part does not have the right type",
1051 })
1052 continue
1053 }
1054 assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
1055 OfString: param.NewOpt(textPart.Text),
1056 }
1057 case ai.ContentTypeToolCall:
1058 toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
1059 if !ok {
1060 warnings = append(warnings, ai.CallWarning{
1061 Type: ai.CallWarningTypeOther,
1062 Message: "assistant message tool part does not have the right type",
1063 })
1064 continue
1065 }
1066 assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
1067 openai.ChatCompletionMessageToolCallUnionParam{
1068 OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
1069 ID: toolCallPart.ToolCallID,
1070 Type: "function",
1071 Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
1072 Name: toolCallPart.ToolName,
1073 Arguments: toolCallPart.Input,
1074 },
1075 },
1076 })
1077 }
1078 }
1079 messages = append(messages, openai.ChatCompletionMessageParamUnion{
1080 OfAssistant: &assistantMsg,
1081 })
1082 case ai.MessageRoleTool:
1083 for _, c := range msg.Content {
1084 if c.GetType() != ai.ContentTypeToolResult {
1085 warnings = append(warnings, ai.CallWarning{
1086 Type: ai.CallWarningTypeOther,
1087 Message: "tool message can only have tool result content",
1088 })
1089 continue
1090 }
1091
1092 toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
1093 if !ok {
1094 warnings = append(warnings, ai.CallWarning{
1095 Type: ai.CallWarningTypeOther,
1096 Message: "tool message result part does not have the right type",
1097 })
1098 continue
1099 }
1100
1101 switch toolResultPart.Output.GetType() {
1102 case ai.ToolResultContentTypeText:
1103 output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
1104 if !ok {
1105 warnings = append(warnings, ai.CallWarning{
1106 Type: ai.CallWarningTypeOther,
1107 Message: "tool result output does not have the right type",
1108 })
1109 continue
1110 }
1111 messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
1112 case ai.ToolResultContentTypeError:
1113 // TODO: check if better handling is needed
1114 output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
1115 if !ok {
1116 warnings = append(warnings, ai.CallWarning{
1117 Type: ai.CallWarningTypeOther,
1118 Message: "tool result output does not have the right type",
1119 })
1120 continue
1121 }
1122 messages = append(messages, openai.ToolMessage(output.Error, toolResultPart.ToolCallID))
1123 }
1124 }
1125 }
1126 }
1127 return messages, warnings
1128}
1129
1130// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta
1131func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
1132 var annotations []openai.ChatCompletionMessageAnnotation
1133
1134 // Parse the raw JSON to extract annotations
1135 var deltaData map[string]interface{}
1136 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
1137 return annotations
1138 }
1139
1140 // Check if annotations exist in the delta
1141 if annotationsData, ok := deltaData["annotations"].([]interface{}); ok {
1142 for _, annotationData := range annotationsData {
1143 if annotationMap, ok := annotationData.(map[string]interface{}); ok {
1144 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
1145 if urlCitationData, ok := annotationMap["url_citation"].(map[string]interface{}); ok {
1146 annotation := openai.ChatCompletionMessageAnnotation{
1147 Type: "url_citation",
1148 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
1149 URL: urlCitationData["url"].(string),
1150 Title: urlCitationData["title"].(string),
1151 },
1152 }
1153 annotations = append(annotations, annotation)
1154 }
1155 }
1156 }
1157 }
1158 }
1159
1160 return annotations
1161}