package openai

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/charmbracelet/ai/ai"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

type provider struct {
	options options
}

type options struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	headers      map[string]string
	client       option.HTTPClient
}

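// Option configures the provider created by New.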
type Option = func(*options)

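// New returns an ai.Provider backed by the OpenAI Chat Completions API.
// Unset options fall back to defaults: https://api.openai.com/v1 as the
// base URL and "openai" as the provider name. A minimal usage sketch,
// assuming the key lives in the OPENAI_API_KEY environment variable:
//
//	provider := New(WithAPIKey(os.Getenv("OPENAI_API_KEY")))
//	model, err := provider.LanguageModel("gpt-4o")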
func New(opts ...Option) ai.Provider {
	options := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&options)
	}

	options.baseURL = cmp.Or(options.baseURL, "https://api.openai.com/v1")
	options.name = cmp.Or(options.name, "openai")

	if options.organization != "" {
		options.headers["OpenAI-Organization"] = options.organization
	}
	if options.project != "" {
		options.headers["OpenAI-Project"] = options.project
	}

	return &provider{options: options}
}

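// WithBaseURL overrides the default API base URL (https://api.openai.com/v1).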
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

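// WithAPIKey sets the API key used to authenticate requests.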
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

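// WithOrganization sets the OpenAI-Organization header on every request.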
func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}

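// WithProject sets the OpenAI-Project header on every request.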
func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}

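// WithName overrides the provider name reported by created models
// (default "openai").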
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

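// WithHeaders merges extra headers into every request.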
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

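// WithHTTPClient sets a custom HTTP client for the underlying OpenAI SDK.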
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}

// LanguageModel implements ai.Provider.
func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	openaiClientOptions := []option.RequestOption{}
	if o.options.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
	}
	if o.options.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
	}

	for key, value := range o.options.headers {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	if o.options.client != nil {
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
	}

	return languageModel{
		modelID:  modelID,
		provider: fmt.Sprintf("%s.chat", o.options.name),
		options:  o.options,
		client:   openai.NewClient(openaiClientOptions...),
	}, nil
}

type languageModel struct {
	provider string
	modelID  string
	client   openai.Client
	options  options
}

// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

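// prepareParams translates an ai.Call into openai.ChatCompletionNewParams,
// returning warnings for settings the target model does not support.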
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	providerOptions := &providerOptions{}
	if v, ok := call.ProviderOptions["openai"]; ok {
		err := ai.ParseOptions(v, providerOptions)
		if err != nil {
			return nil, nil, err
		}
	}
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID
	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// When both are set, top_logprobs takes precedence over logprobs.
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam.
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string.
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case reasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case reasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case reasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case reasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(o.modelID) {
		// Remove unsupported settings for reasoning models.
		// See https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_p",
				Details: "top_p is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "frequency_penalty",
				Details: "frequency_penalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "presence_penalty",
				Details: "presence_penalty is not supported for reasoning models",
			})
		}
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "logit_bias",
				Details: "logit_bias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "logprobs",
				Details: "logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "top_logprobs",
				Details: "top_logprobs is not supported for reasoning models",
			})
		}

		// Reasoning models use max_completion_tokens instead of max_tokens.
		if call.MaxOutputTokens != nil {
			if providerOptions.MaxCompletionTokens == nil {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models.
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed",
			})
		}
	}

	// Handle service tier validation.
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "service_tier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "service_tier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access; gpt-5-nano is not supported",
			})
		}
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}

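// handleError converts *openai.Error values into ai.APICallError, preserving
// the request/response dumps and response headers; other errors pass through.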
func (o languageModel) handleError(err error) error {
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			if len(h) == 0 {
				continue
			}
			headers[strings.ToLower(k)] = h[len(h)-1]
		}
		return ai.NewAPICallError(
			apiErr.Message,
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

// Generate implements ai.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata.
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available.
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction tokens if available.
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}

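// toolCall tracks the incremental state of a streamed tool call, keyed by
// the tool call index within the stream.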
type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Stream implements ai.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]toolCall)

	// Build provider metadata for streaming.
	streamProviderMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	acc := openai.ChatCompletionAccumulator{}
	var usage ai.Usage
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			if chunk.Usage.TotalTokens > 0 {
				// We track usage here because the accumulator does not add
				// prompt token details.
				completionTokenDetails := chunk.Usage.CompletionTokensDetails
				promptTokenDetails := chunk.Usage.PromptTokensDetails
				usage = ai.Usage{
					InputTokens:     chunk.Usage.PromptTokens,
					OutputTokens:    chunk.Usage.CompletionTokens,
					TotalTokens:     chunk.Usage.TotalTokens,
					ReasoningTokens: completionTokenDetails.ReasoningTokens,
					CacheReadTokens: promptTokenDetails.CachedTokens,
				}

				// Add prediction tokens if available.
				if completionTokenDetails.AcceptedPredictionTokens > 0 {
					streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
				}
				if completionTokenDetails.RejectedPredictionTokens > 0 {
					streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
				}
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if isValidJSON(existingToolCall.arguments) {
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(ai.StreamPart{
									Type:          ai.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// New tool call: validate the delta before tracking it.
							var err error
							if toolCallDelta.Type != "function" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								yield(ai.StreamPart{
									Type:  ai.StreamPartTypeError,
									Error: o.handleError(err),
								})
								return
							}

							if !yield(ai.StreamPart{
								Type:         ai.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = toolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(ai.StreamPart{
									Type:  ai.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if isValidJSON(toolCalls[toolCallDelta.Index].arguments) {
									if !yield(ai.StreamPart{
										Type: ai.StreamPartTypeToolInputEnd,
										ID:   toolCallDelta.ID,
									}) {
										return
									}

									if !yield(ai.StreamPart{
										Type:          ai.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}
			}

			// Check for annotations in the delta's raw JSON.
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(ai.StreamPart{
								Type:       ai.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: ai.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err != nil && !errors.Is(err, io.EOF) {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: o.handleError(err),
			})
			return
		}

		// Finished: close any open text block and emit the final parts.
		if isActiveText {
			isActiveText = false
			if !yield(ai.StreamPart{
				Type: ai.StreamPartTypeTextEnd,
				ID:   "0",
			}) {
				return
			}
		}

		// Add logprobs if available.
		if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
			streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
		}

		// Handle annotations/citations from the accumulated response.
		if len(acc.Choices) > 0 {
			for _, annotation := range acc.Choices[0].Message.Annotations {
				if annotation.Type == "url_citation" {
					if !yield(ai.StreamPart{
						Type:       ai.StreamPartTypeSource,
						ID:         acc.ID,
						SourceType: ai.SourceTypeURL,
						URL:        annotation.URLCitation.URL,
						Title:      annotation.URLCitation.Title,
					}) {
						return
					}
				}
			}
		}

		// Guard against responses with no choices before reading the finish reason.
		finishReason := ai.FinishReasonUnknown
		if len(acc.Choices) > 0 {
			finishReason = mapOpenAiFinishReason(acc.Choices[0].FinishReason)
		}
		yield(ai.StreamPart{
			Type:             ai.StreamPartTypeFinish,
			Usage:            usage,
			FinishReason:     finishReason,
			ProviderMetadata: streamProviderMetadata,
		})
	}, nil
}

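// mapOpenAiFinishReason maps OpenAI finish reasons onto ai.FinishReason values.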
func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

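// isReasoningModel reports whether the model uses reasoning-specific request
// settings (o-series and gpt-5 models, excluding the conversational
// gpt-5-chat family).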
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-chat"))
}

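// isSearchPreviewModel reports whether the model is a search preview variant.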
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

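// supportsFlexProcessing reports whether the model accepts the "flex"
// service tier.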
func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

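// supportsPriorityProcessing reports whether the model accepts the "priority"
// service tier.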
func supportsPriorityProcessing(modelID string) bool {
	// gpt-5-nano is excluded per the warning emitted in prepareParams; the
	// gpt-5 prefix otherwise covers gpt-5-mini as well.
	return strings.HasPrefix(modelID, "gpt-4") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-nano")) ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}

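// toOpenAiTools converts ai.Tool definitions and an optional ai.ToolChoice
// into their Chat Completions equivalents, warning on unsupported tools.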
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		// Any other value is treated as the name of a specific tool to force.
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}

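// toPrompt converts an ai.Prompt into Chat Completions messages, collecting
// warnings for content the API cannot represent.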
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				if strings.TrimSpace(textPart.Text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// Simple user message with only text content.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// Text content plus attachments. Images, WAV/MP3 audio, and PDFs
			// are supported.
			// TODO: add the supported media types to the language model so we
			// can use them to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail.
						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
							if detail, ok := providerOptions["imageDetail"].(string); ok {
								imageURL.Detail = detail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files.
						dataStr := string(filePart.Data)

						// Check if the data looks like a file ID (starts with "file-").
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data.
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate a default filename based on the content index.
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// Simple assistant message with only text content.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations.
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta.
	if annotationsData, ok := deltaData["annotations"].([]any); ok {
		for _, annotationData := range annotationsData {
			annotationMap, ok := annotationData.(map[string]any)
			if !ok {
				continue
			}
			if annotationType, ok := annotationMap["type"].(string); !ok || annotationType != "url_citation" {
				continue
			}
			urlCitationData, ok := annotationMap["url_citation"].(map[string]any)
			if !ok {
				continue
			}
			// Use checked assertions so malformed annotations cannot panic.
			url, _ := urlCitationData["url"].(string)
			title, _ := urlCitationData["title"].(string)
			annotations = append(annotations, openai.ChatCompletionMessageAnnotation{
				Type: "url_citation",
				URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
					URL:   url,
					Title: title,
				},
			})
		}
	}

	return annotations
}