package openai

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/charmbracelet/ai"
	"github.com/charmbracelet/ai/internal/jsonext"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

type provider struct {
	options options
}

type options struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	headers      map[string]string
	client       option.HTTPClient
}

// Option configures the provider.
type Option = func(*options)

// New creates an OpenAI provider with the given options.
func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	options.baseURL = cmp.Or(options.baseURL, "https://api.openai.com/v1")
	options.name = cmp.Or(options.name, "openai")

	if options.organization != "" {
		options.headers["OpenAI-Organization"] = options.organization
	}
	if options.project != "" {
		options.headers["OpenAI-Project"] = options.project
	}

	return &provider{options: options}
}
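
// A minimal usage sketch. The environment variable and model ID below are
// illustrative assumptions, not values this package requires:
//
//	provider := openai.New(
//		openai.WithAPIKey(os.Getenv("OPENAI_API_KEY")),
//	)
//	model, err := provider.LanguageModel("gpt-4o")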

// WithBaseURL overrides the default API base URL.
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

// WithAPIKey sets the API key used to authenticate requests.
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

// WithOrganization sets the OpenAI-Organization header.
func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}

// WithProject sets the OpenAI-Project header.
func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}

// WithName overrides the provider name used in metadata (defaults to "openai").
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

// WithHeaders merges the given headers into every request.
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

// WithHTTPClient sets a custom HTTP client.
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}

// LanguageModel implements ai.Provider.
func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	openaiClientOptions := []option.RequestOption{}
	if o.options.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
	}
	if o.options.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
	}

	for key, value := range o.options.headers {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	if o.options.client != nil {
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
	}

	return languageModel{
		modelID:  modelID,
		provider: fmt.Sprintf("%s.chat", o.options.name),
		options:  o.options,
		client:   openai.NewClient(openaiClientOptions...),
	}, nil
}

type languageModel struct {
	provider string
	modelID  string
	client   openai.Client
	options  options
}

// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	providerOptions := &providerOptions{}
	if v, ok := call.ProviderOptions["openai"]; ok {
		if err := ai.ParseOptions(v, providerOptions); err != nil {
			return nil, nil, err
		}
	}
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID
	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// When both are set, prefer top_logprobs and drop the logprobs flag.
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string, skipping non-string values
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case reasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case reasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case reasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case reasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, nil, fmt.Errorf("unsupported reasoning effort %q", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_p",
				Details: "top_p is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "frequency_penalty",
				Details: "frequency_penalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "presence_penalty",
				Details: "presence_penalty is not supported for reasoning models",
			})
		}
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "logit_bias",
				Details: "logit_bias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "logprobs",
				Details: "logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_logprobs",
				Details: "top_logprobs is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			if providerOptions.MaxCompletionTokens == nil {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for search preview models and has been removed",
			})
		}
	}

	// Handle service tier validation
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "service_tier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "service_tier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access; gpt-5-nano is not supported",
			})
		}
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
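
// A hedged sketch of passing OpenAI-specific options through ai.Call. The
// JSON keys are defined by the providerOptions struct elsewhere in this
// package; the keys and map types below are illustrative assumptions:
//
//	call := ai.Call{
//		Prompt: prompt, // built elsewhere
//		ProviderOptions: map[string]map[string]any{
//			"openai": {
//				"user":            "user-1234", // hypothetical key
//				"reasoningEffort": "low",       // hypothetical key
//			},
//		},
//	}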

func (o languageModel) handleError(err error) error {
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			// Guard against empty value slices so indexing cannot panic.
			if len(h) == 0 {
				continue
			}
			headers[strings.ToLower(k)] = h[len(h)-1]
		}
		return ai.NewAPICallError(
			apiErr.Message,
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}
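
// Callers can recover the structured error data. A sketch, assuming the ai
// package exposes *ai.APICallError as the concrete type behind
// ai.NewAPICallError, with StatusCode and URL fields (both assumptions):
//
//	_, err := model.Generate(ctx, call)
//	var apiErr *ai.APICallError
//	if errors.As(err, &apiErr) {
//		log.Printf("status %d from %s", apiErr.StatusCode, apiErr.URL)
//	}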

// Generate implements ai.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}
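
// A minimal sketch of calling Generate and reading the typed content parts
// produced above (ctx, model, and prompt are assumed to exist):
//
//	resp, err := model.Generate(ctx, ai.Call{Prompt: prompt})
//	if err != nil {
//		return err
//	}
//	for _, part := range resp.Content {
//		switch p := part.(type) {
//		case ai.TextContent:
//			fmt.Print(p.Text)
//		case ai.ToolCallContent:
//			fmt.Printf("tool %s(%s)\n", p.ToolName, p.Input)
//		}
//	}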

type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Stream implements ai.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]toolCall)

	// Build provider metadata for streaming
	streamProviderMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	acc := openai.ChatCompletionAccumulator{}
	var usage ai.Usage
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			if chunk.Usage.TotalTokens > 0 {
				// we do this here because the accumulator does not add prompt details
				completionTokenDetails := chunk.Usage.CompletionTokensDetails
				promptTokenDetails := chunk.Usage.PromptTokensDetails
				usage = ai.Usage{
					InputTokens:     chunk.Usage.PromptTokens,
					OutputTokens:    chunk.Usage.CompletionTokens,
					TotalTokens:     chunk.Usage.TotalTokens,
					ReasoningTokens: completionTokenDetails.ReasoningTokens,
					CacheReadTokens: promptTokenDetails.CachedTokens,
				}

				// Add prediction tokens if available
				if completionTokenDetails.AcceptedPredictionTokens > 0 {
					streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
				}
				if completionTokenDetails.RejectedPredictionTokens > 0 {
					streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
				}
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if jsonext.IsValidJSON(existingToolCall.arguments) {
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(ai.StreamPart{
									Type:          ai.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// New tool call: validate the delta before tracking it.
							var err error
							if toolCallDelta.Type != "function" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								yield(ai.StreamPart{
									Type:  ai.StreamPartTypeError,
									Error: o.handleError(err),
								})
								return
							}

							if !yield(ai.StreamPart{
								Type:         ai.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = toolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(ai.StreamPart{
									Type:  ai.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if jsonext.IsValidJSON(exTc.arguments) {
									if !yield(ai.StreamPart{
										Type: ai.StreamPartTypeToolInputEnd,
										ID:   exTc.id,
									}) {
										return
									}

									if !yield(ai.StreamPart{
										Type:          ai.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
						}
					}
				}
			}

			// Check for annotations in the delta's raw JSON
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(ai.StreamPart{
								Type:       ai.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: ai.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: o.handleError(err),
			})
			return
		}

		// finished
		if isActiveText {
			isActiveText = false
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		// Add logprobs if available
		if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
			streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
		}

		// Handle annotations/citations from the accumulated response and map
		// the finish reason, guarding against responses without choices.
		finishReason := ai.FinishReasonUnknown
		if len(acc.Choices) > 0 {
			for _, annotation := range acc.Choices[0].Message.Annotations {
				if annotation.Type == "url_citation" {
					if !yield(ai.StreamPart{
						Type:       ai.StreamPartTypeSource,
						ID:         acc.ID,
						SourceType: ai.SourceTypeURL,
						URL:        annotation.URLCitation.URL,
						Title:      annotation.URLCitation.Title,
					}) {
						return
					}
				}
			}
			finishReason = mapOpenAiFinishReason(acc.Choices[0].FinishReason)
		}

		yield(ai.StreamPart{
			Type:             ai.StreamPartTypeFinish,
			Usage:            usage,
			FinishReason:     finishReason,
			ProviderMetadata: streamProviderMetadata,
		})
	}, nil
}
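
// The returned ai.StreamResponse is a push iterator (a yield-style function),
// so it can be consumed with a range-over-func loop. A sketch, assuming ctx,
// model, and prompt exist:
//
//	stream, err := model.Stream(ctx, ai.Call{Prompt: prompt})
//	if err != nil {
//		return err
//	}
//	for part := range stream {
//		switch part.Type {
//		case ai.StreamPartTypeTextDelta:
//			fmt.Print(part.Delta)
//		case ai.StreamPartTypeError:
//			return part.Error
//		}
//	}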

func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

func isReasoningModel(modelID string) bool {
	// gpt-5-chat models are conventional chat models, not reasoning models.
	return (strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")) &&
		!strings.HasPrefix(modelID, "gpt-5-chat")
}

func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

func supportsPriorityProcessing(modelID string) bool {
	// gpt-5-nano does not support priority processing, so it is excluded
	// from the gpt-5 prefix match (which also covers gpt-5-mini).
	return strings.HasPrefix(modelID, "gpt-4") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-nano")) ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}
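
// Illustrative evaluations of the gating helpers above (the model IDs are
// examples, not an exhaustive list):
//
//	isReasoningModel("o3-mini")              // true
//	isReasoningModel("gpt-5-chat-latest")    // false
//	supportsFlexProcessing("o4-mini")        // true
//	supportsPriorityProcessing("gpt-5-nano") // false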

func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
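
// A sketch of the function-tool mapping above. The tool name, description,
// and schema literal are illustrative assumptions; the ai.FunctionTool fields
// match the ones read in toOpenAiTools:
//
//	choice := ai.ToolChoiceAuto
//	tools, toolChoice, warnings := toOpenAiTools([]ai.Tool{
//		ai.FunctionTool{
//			Name:        "get_weather",
//			Description: "Get the weather for a location.",
//			InputSchema: map[string]any{
//				"type": "object",
//				"properties": map[string]any{
//					"location": map[string]any{"type": "string"},
//				},
//			},
//		},
//	}, &choice)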

func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				if strings.TrimSpace(textPart.Text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// simple user message: just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// Text content plus attachments. For now only a limited set of
			// media types is supported.
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
							if detail, ok := providerOptions["imageDetail"].(string); ok {
								imageURL.Detail = detail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate default filename based on content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// simple assistant message: just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta
	annotationsData, ok := deltaData["annotations"].([]any)
	if !ok {
		return annotations
	}
	for _, annotationData := range annotationsData {
		annotationMap, ok := annotationData.(map[string]any)
		if !ok {
			continue
		}
		if annotationType, ok := annotationMap["type"].(string); !ok || annotationType != "url_citation" {
			continue
		}
		urlCitationData, ok := annotationMap["url_citation"].(map[string]any)
		if !ok {
			continue
		}
		// Use checked assertions so malformed payloads cannot panic.
		url, _ := urlCitationData["url"].(string)
		title, _ := urlCitationData["title"].(string)
		annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
			Type: "url_citation",
			URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
				URL:   url,
				Title: title,
			},
		})
	}

	return annotations
}
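
// The delta shape parseAnnotationsFromDelta expects, reconstructed from the
// keys read above (a sketch; real payloads may carry additional fields):
//
//	{
//	  "annotations": [
//	    {
//	      "type": "url_citation",
//	      "url_citation": {"url": "https://example.com", "title": "Example"}
//	    }
//	  ]
//	}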