1package openai
2
3import (
4 "cmp"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "maps"
12 "strings"
13
14 "github.com/charmbracelet/ai/ai"
15 xjson "github.com/charmbracelet/x/json"
16 "github.com/google/uuid"
17 "github.com/openai/openai-go/v2"
18 "github.com/openai/openai-go/v2/option"
19 "github.com/openai/openai-go/v2/packages/param"
20 "github.com/openai/openai-go/v2/shared"
21)
22
// provider is the OpenAI implementation of ai.Provider. It holds the
// fully-resolved configuration used to construct language models.
type provider struct {
	options options
}
26
// options holds the configurable settings for the OpenAI provider.
// Unset fields are filled with defaults by New.
type options struct {
	baseURL      string            // API base URL; defaults to https://api.openai.com/v1
	apiKey       string            // bearer token used to authenticate requests
	organization string            // forwarded as the OpenAi-Organization header when set
	project      string            // forwarded as the OpenAi-Project header when set
	name         string            // provider label used in metadata; defaults to "openai"
	headers      map[string]string // extra headers applied to every request
	client       option.HTTPClient // optional custom HTTP client
}
36
37type Option = func(*options)
38
39func New(opts ...Option) ai.Provider {
40 options := options{
41 headers: map[string]string{},
42 }
43 for _, o := range opts {
44 o(&options)
45 }
46
47 options.baseURL = cmp.Or(options.baseURL, "https://api.openai.com/v1")
48 options.name = cmp.Or(options.name, "openai")
49
50 if options.organization != "" {
51 options.headers["OpenAi-Organization"] = options.organization
52 }
53 if options.project != "" {
54 options.headers["OpenAi-Project"] = options.project
55 }
56
57 return &provider{options: options}
58}
59
60func WithBaseURL(baseURL string) Option {
61 return func(o *options) {
62 o.baseURL = baseURL
63 }
64}
65
66func WithAPIKey(apiKey string) Option {
67 return func(o *options) {
68 o.apiKey = apiKey
69 }
70}
71
72func WithOrganization(organization string) Option {
73 return func(o *options) {
74 o.organization = organization
75 }
76}
77
78func WithProject(project string) Option {
79 return func(o *options) {
80 o.project = project
81 }
82}
83
84func WithName(name string) Option {
85 return func(o *options) {
86 o.name = name
87 }
88}
89
90func WithHeaders(headers map[string]string) Option {
91 return func(o *options) {
92 maps.Copy(o.headers, headers)
93 }
94}
95
96func WithHTTPClient(client option.HTTPClient) Option {
97 return func(o *options) {
98 o.client = client
99 }
100}
101
102// LanguageModel implements ai.Provider.
103func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
104 openaiClientOptions := []option.RequestOption{}
105 if o.options.apiKey != "" {
106 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
107 }
108 if o.options.baseURL != "" {
109 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
110 }
111
112 for key, value := range o.options.headers {
113 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
114 }
115
116 if o.options.client != nil {
117 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
118 }
119
120 return languageModel{
121 modelID: modelID,
122 provider: fmt.Sprintf("%s.chat", o.options.name),
123 options: o.options,
124 client: openai.NewClient(openaiClientOptions...),
125 }, nil
126}
127
// languageModel is the chat-completions implementation of ai.LanguageModel.
type languageModel struct {
	provider string        // provider label, e.g. "openai.chat"
	modelID  string        // OpenAI model identifier
	client   openai.Client // configured OpenAI SDK client
	options  options       // provider-level configuration
}
134
// Model implements ai.LanguageModel. It returns the model identifier this
// instance was created with.
func (o languageModel) Model() string {
	return o.modelID
}
139
// Provider implements ai.LanguageModel. It returns the provider label
// (e.g. "openai.chat").
func (o languageModel) Provider() string {
	return o.provider
}
144
145func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
146 params := &openai.ChatCompletionNewParams{}
147 messages, warnings := toPrompt(call.Prompt)
148 providerOptions := &ProviderOptions{}
149 if v, ok := call.ProviderOptions["openai"]; ok {
150 providerOptions, ok = v.(*ProviderOptions)
151 if !ok {
152 return nil, nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
153 }
154 }
155 if call.TopK != nil {
156 warnings = append(warnings, ai.CallWarning{
157 Type: ai.CallWarningTypeUnsupportedSetting,
158 Setting: "top_k",
159 })
160 }
161 params.Messages = messages
162 params.Model = o.modelID
163 if providerOptions.LogitBias != nil {
164 params.LogitBias = providerOptions.LogitBias
165 }
166 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
167 providerOptions.LogProbs = nil
168 }
169 if providerOptions.LogProbs != nil {
170 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
171 }
172 if providerOptions.TopLogProbs != nil {
173 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
174 }
175 if providerOptions.User != nil {
176 params.User = param.NewOpt(*providerOptions.User)
177 }
178 if providerOptions.ParallelToolCalls != nil {
179 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
180 }
181
182 if call.MaxOutputTokens != nil {
183 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
184 }
185 if call.Temperature != nil {
186 params.Temperature = param.NewOpt(*call.Temperature)
187 }
188 if call.TopP != nil {
189 params.TopP = param.NewOpt(*call.TopP)
190 }
191 if call.FrequencyPenalty != nil {
192 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
193 }
194 if call.PresencePenalty != nil {
195 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
196 }
197
198 if providerOptions.MaxCompletionTokens != nil {
199 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
200 }
201
202 if providerOptions.TextVerbosity != nil {
203 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
204 }
205 if providerOptions.Prediction != nil {
206 // Convert map[string]any to ChatCompletionPredictionContentParam
207 if content, ok := providerOptions.Prediction["content"]; ok {
208 if contentStr, ok := content.(string); ok {
209 params.Prediction = openai.ChatCompletionPredictionContentParam{
210 Content: openai.ChatCompletionPredictionContentContentUnionParam{
211 OfString: param.NewOpt(contentStr),
212 },
213 }
214 }
215 }
216 }
217 if providerOptions.Store != nil {
218 params.Store = param.NewOpt(*providerOptions.Store)
219 }
220 if providerOptions.Metadata != nil {
221 // Convert map[string]any to map[string]string
222 metadata := make(map[string]string)
223 for k, v := range providerOptions.Metadata {
224 if str, ok := v.(string); ok {
225 metadata[k] = str
226 }
227 }
228 params.Metadata = metadata
229 }
230 if providerOptions.PromptCacheKey != nil {
231 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
232 }
233 if providerOptions.SafetyIdentifier != nil {
234 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
235 }
236 if providerOptions.ServiceTier != nil {
237 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
238 }
239
240 if providerOptions.ReasoningEffort != nil {
241 switch *providerOptions.ReasoningEffort {
242 case ReasoningEffortMinimal:
243 params.ReasoningEffort = shared.ReasoningEffortMinimal
244 case ReasoningEffortLow:
245 params.ReasoningEffort = shared.ReasoningEffortLow
246 case ReasoningEffortMedium:
247 params.ReasoningEffort = shared.ReasoningEffortMedium
248 case ReasoningEffortHigh:
249 params.ReasoningEffort = shared.ReasoningEffortHigh
250 default:
251 return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
252 }
253 }
254
255 if isReasoningModel(o.modelID) {
256 // remove unsupported settings for reasoning models
257 // see https://platform.openai.com/docs/guides/reasoning#limitations
258 if call.Temperature != nil {
259 params.Temperature = param.Opt[float64]{}
260 warnings = append(warnings, ai.CallWarning{
261 Type: ai.CallWarningTypeUnsupportedSetting,
262 Setting: "temperature",
263 Details: "temperature is not supported for reasoning models",
264 })
265 }
266 if call.TopP != nil {
267 params.TopP = param.Opt[float64]{}
268 warnings = append(warnings, ai.CallWarning{
269 Type: ai.CallWarningTypeUnsupportedSetting,
270 Setting: "TopP",
271 Details: "TopP is not supported for reasoning models",
272 })
273 }
274 if call.FrequencyPenalty != nil {
275 params.FrequencyPenalty = param.Opt[float64]{}
276 warnings = append(warnings, ai.CallWarning{
277 Type: ai.CallWarningTypeUnsupportedSetting,
278 Setting: "FrequencyPenalty",
279 Details: "FrequencyPenalty is not supported for reasoning models",
280 })
281 }
282 if call.PresencePenalty != nil {
283 params.PresencePenalty = param.Opt[float64]{}
284 warnings = append(warnings, ai.CallWarning{
285 Type: ai.CallWarningTypeUnsupportedSetting,
286 Setting: "PresencePenalty",
287 Details: "PresencePenalty is not supported for reasoning models",
288 })
289 }
290 if providerOptions.LogitBias != nil {
291 params.LogitBias = nil
292 warnings = append(warnings, ai.CallWarning{
293 Type: ai.CallWarningTypeUnsupportedSetting,
294 Setting: "LogitBias",
295 Message: "LogitBias is not supported for reasoning models",
296 })
297 }
298 if providerOptions.LogProbs != nil {
299 params.Logprobs = param.Opt[bool]{}
300 warnings = append(warnings, ai.CallWarning{
301 Type: ai.CallWarningTypeUnsupportedSetting,
302 Setting: "Logprobs",
303 Message: "Logprobs is not supported for reasoning models",
304 })
305 }
306 if providerOptions.TopLogProbs != nil {
307 params.TopLogprobs = param.Opt[int64]{}
308 warnings = append(warnings, ai.CallWarning{
309 Type: ai.CallWarningTypeUnsupportedSetting,
310 Setting: "TopLogprobs",
311 Message: "TopLogprobs is not supported for reasoning models",
312 })
313 }
314
315 // reasoning models use max_completion_tokens instead of max_tokens
316 if call.MaxOutputTokens != nil {
317 if providerOptions.MaxCompletionTokens == nil {
318 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
319 }
320 params.MaxTokens = param.Opt[int64]{}
321 }
322 }
323
324 // Handle search preview models
325 if isSearchPreviewModel(o.modelID) {
326 if call.Temperature != nil {
327 params.Temperature = param.Opt[float64]{}
328 warnings = append(warnings, ai.CallWarning{
329 Type: ai.CallWarningTypeUnsupportedSetting,
330 Setting: "temperature",
331 Details: "temperature is not supported for the search preview models and has been removed.",
332 })
333 }
334 }
335
336 // Handle service tier validation
337 if providerOptions.ServiceTier != nil {
338 serviceTier := *providerOptions.ServiceTier
339 if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
340 params.ServiceTier = ""
341 warnings = append(warnings, ai.CallWarning{
342 Type: ai.CallWarningTypeUnsupportedSetting,
343 Setting: "ServiceTier",
344 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
345 })
346 } else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
347 params.ServiceTier = ""
348 warnings = append(warnings, ai.CallWarning{
349 Type: ai.CallWarningTypeUnsupportedSetting,
350 Setting: "ServiceTier",
351 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
352 })
353 }
354 }
355
356 if len(call.Tools) > 0 {
357 tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
358 params.Tools = tools
359 if toolChoice != nil {
360 params.ToolChoice = *toolChoice
361 }
362 warnings = append(warnings, toolWarnings...)
363 }
364 return params, warnings, nil
365}
366
367func (o languageModel) handleError(err error) error {
368 var apiErr *openai.Error
369 if errors.As(err, &apiErr) {
370 requestDump := apiErr.DumpRequest(true)
371 responseDump := apiErr.DumpResponse(true)
372 headers := map[string]string{}
373 for k, h := range apiErr.Response.Header {
374 v := h[len(h)-1]
375 headers[strings.ToLower(k)] = v
376 }
377 return ai.NewAPICallError(
378 apiErr.Message,
379 apiErr.Request.URL.String(),
380 string(requestDump),
381 apiErr.StatusCode,
382 headers,
383 string(responseDump),
384 apiErr,
385 false,
386 )
387 }
388 return err
389}
390
391// Generate implements ai.LanguageModel.
// Generate implements ai.LanguageModel. It performs a single non-streaming
// chat completion call and converts the first choice of the response into
// ai.Response content (text, tool calls, and url_citation sources).
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	// Only the first choice is used; n > 1 is never requested.
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		// Some responses may omit the tool call ID; generate one so callers
		// can always correlate tool results.
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations: only url_citation is mapped.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata.
	providerMetadata := &ProviderMetadata{}
	// Add logprobs if available.
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata.Logprobs = choice.Logprobs.Content
	}

	// Add prediction tokens if available.
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: ai.ProviderMetadata{
			"openai": providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
474
// toolCall tracks the incremental state of a streamed tool call while its
// JSON arguments are still being accumulated across deltas.
type toolCall struct {
	id          string // tool call ID from the first delta
	name        string // function name from the first delta
	arguments   string // concatenated JSON argument fragments
	hasFinished bool   // set once the complete tool-call part has been emitted
}
481
482// Stream implements ai.LanguageModel.
483func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
484 params, warnings, err := o.prepareParams(call)
485 if err != nil {
486 return nil, err
487 }
488
489 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
490 IncludeUsage: openai.Bool(true),
491 }
492
493 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
494 isActiveText := false
495 toolCalls := make(map[int64]toolCall)
496
497 // Build provider metadata for streaming
498 streamProviderMetadata := &ProviderMetadata{}
499 acc := openai.ChatCompletionAccumulator{}
500 var usage ai.Usage
501 return func(yield func(ai.StreamPart) bool) {
502 if len(warnings) > 0 {
503 if !yield(ai.StreamPart{
504 Type: ai.StreamPartTypeWarnings,
505 Warnings: warnings,
506 }) {
507 return
508 }
509 }
510 for stream.Next() {
511 chunk := stream.Current()
512 acc.AddChunk(chunk)
513 if chunk.Usage.TotalTokens > 0 {
514 // we do this here because the acc does not add prompt details
515 completionTokenDetails := chunk.Usage.CompletionTokensDetails
516 promptTokenDetails := chunk.Usage.PromptTokensDetails
517 usage = ai.Usage{
518 InputTokens: chunk.Usage.PromptTokens,
519 OutputTokens: chunk.Usage.CompletionTokens,
520 TotalTokens: chunk.Usage.TotalTokens,
521 ReasoningTokens: completionTokenDetails.ReasoningTokens,
522 CacheReadTokens: promptTokenDetails.CachedTokens,
523 }
524
525 // Add prediction tokens if available
526 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
527 if completionTokenDetails.AcceptedPredictionTokens > 0 {
528 streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
529 }
530 if completionTokenDetails.RejectedPredictionTokens > 0 {
531 streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
532 }
533 }
534 }
535 if len(chunk.Choices) == 0 {
536 continue
537 }
538 for _, choice := range chunk.Choices {
539 switch {
540 case choice.Delta.Content != "":
541 if !isActiveText {
542 isActiveText = true
543 if !yield(ai.StreamPart{
544 Type: ai.StreamPartTypeTextStart,
545 ID: "0",
546 }) {
547 return
548 }
549 }
550 if !yield(ai.StreamPart{
551 Type: ai.StreamPartTypeTextDelta,
552 ID: "0",
553 Delta: choice.Delta.Content,
554 }) {
555 return
556 }
557 case len(choice.Delta.ToolCalls) > 0:
558 if isActiveText {
559 isActiveText = false
560 if !yield(ai.StreamPart{
561 Type: ai.StreamPartTypeTextEnd,
562 ID: "0",
563 }) {
564 return
565 }
566 }
567
568 for _, toolCallDelta := range choice.Delta.ToolCalls {
569 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
570 if existingToolCall.hasFinished {
571 continue
572 }
573 if toolCallDelta.Function.Arguments != "" {
574 existingToolCall.arguments += toolCallDelta.Function.Arguments
575 }
576 if !yield(ai.StreamPart{
577 Type: ai.StreamPartTypeToolInputDelta,
578 ID: existingToolCall.id,
579 Delta: toolCallDelta.Function.Arguments,
580 }) {
581 return
582 }
583 toolCalls[toolCallDelta.Index] = existingToolCall
584 if xjson.IsValid(existingToolCall.arguments) {
585 if !yield(ai.StreamPart{
586 Type: ai.StreamPartTypeToolInputEnd,
587 ID: existingToolCall.id,
588 }) {
589 return
590 }
591
592 if !yield(ai.StreamPart{
593 Type: ai.StreamPartTypeToolCall,
594 ID: existingToolCall.id,
595 ToolCallName: existingToolCall.name,
596 ToolCallInput: existingToolCall.arguments,
597 }) {
598 return
599 }
600 existingToolCall.hasFinished = true
601 toolCalls[toolCallDelta.Index] = existingToolCall
602 }
603 } else {
604 // Does not exist
605 var err error
606 if toolCallDelta.Type != "function" {
607 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
608 }
609 if toolCallDelta.ID == "" {
610 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
611 }
612 if toolCallDelta.Function.Name == "" {
613 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
614 }
615 if err != nil {
616 yield(ai.StreamPart{
617 Type: ai.StreamPartTypeError,
618 Error: o.handleError(stream.Err()),
619 })
620 return
621 }
622
623 if !yield(ai.StreamPart{
624 Type: ai.StreamPartTypeToolInputStart,
625 ID: toolCallDelta.ID,
626 ToolCallName: toolCallDelta.Function.Name,
627 }) {
628 return
629 }
630 toolCalls[toolCallDelta.Index] = toolCall{
631 id: toolCallDelta.ID,
632 name: toolCallDelta.Function.Name,
633 arguments: toolCallDelta.Function.Arguments,
634 }
635
636 exTc := toolCalls[toolCallDelta.Index]
637 if exTc.arguments != "" {
638 if !yield(ai.StreamPart{
639 Type: ai.StreamPartTypeToolInputDelta,
640 ID: exTc.id,
641 Delta: exTc.arguments,
642 }) {
643 return
644 }
645 if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
646 if !yield(ai.StreamPart{
647 Type: ai.StreamPartTypeToolInputEnd,
648 ID: toolCallDelta.ID,
649 }) {
650 return
651 }
652
653 if !yield(ai.StreamPart{
654 Type: ai.StreamPartTypeToolCall,
655 ID: exTc.id,
656 ToolCallName: exTc.name,
657 ToolCallInput: exTc.arguments,
658 }) {
659 return
660 }
661 exTc.hasFinished = true
662 toolCalls[toolCallDelta.Index] = exTc
663 }
664 }
665 continue
666 }
667 }
668 }
669 }
670
671 // Check for annotations in the delta's raw JSON
672 for _, choice := range chunk.Choices {
673 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
674 for _, annotation := range annotations {
675 if annotation.Type == "url_citation" {
676 if !yield(ai.StreamPart{
677 Type: ai.StreamPartTypeSource,
678 ID: uuid.NewString(),
679 SourceType: ai.SourceTypeURL,
680 URL: annotation.URLCitation.URL,
681 Title: annotation.URLCitation.Title,
682 }) {
683 return
684 }
685 }
686 }
687 }
688 }
689 }
690 err := stream.Err()
691 if err == nil || errors.Is(err, io.EOF) {
692 // finished
693 if isActiveText {
694 isActiveText = false
695 if !yield(ai.StreamPart{
696 Type: ai.StreamPartTypeTextEnd,
697 ID: "0",
698 }) {
699 return
700 }
701 }
702
703 // Add logprobs if available
704 if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
705 streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
706 }
707
708 // Handle annotations/citations from accumulated response
709 if len(acc.Choices) > 0 {
710 for _, annotation := range acc.Choices[0].Message.Annotations {
711 if annotation.Type == "url_citation" {
712 if !yield(ai.StreamPart{
713 Type: ai.StreamPartTypeSource,
714 ID: acc.ID,
715 SourceType: ai.SourceTypeURL,
716 URL: annotation.URLCitation.URL,
717 Title: annotation.URLCitation.Title,
718 }) {
719 return
720 }
721 }
722 }
723 }
724
725 finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason)
726 yield(ai.StreamPart{
727 Type: ai.StreamPartTypeFinish,
728 Usage: usage,
729 FinishReason: finishReason,
730 ProviderMetadata: ai.ProviderMetadata{
731 "openai": streamProviderMetadata,
732 },
733 })
734 return
735 } else {
736 yield(ai.StreamPart{
737 Type: ai.StreamPartTypeError,
738 Error: o.handleError(err),
739 })
740 return
741 }
742 }, nil
743}
744
745func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
746 switch finishReason {
747 case "stop":
748 return ai.FinishReasonStop
749 case "length":
750 return ai.FinishReasonLength
751 case "content_filter":
752 return ai.FinishReasonContentFilter
753 case "function_call", "tool_calls":
754 return ai.FinishReasonToolCalls
755 default:
756 return ai.FinishReasonUnknown
757 }
758}
759
// isReasoningModel reports whether the model uses the reasoning-model
// parameter restrictions (o-series and gpt-5 families).
func isReasoningModel(modelID string) bool {
	// The "gpt-5" prefix already covers "gpt-5-chat", so the previous extra
	// check for it was dead code. NOTE(review): the bare "o" prefix matches
	// ANY model ID starting with the letter o, not just o1/o3/o4 — confirm
	// this breadth is intended before tightening it.
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
763
// isSearchPreviewModel reports whether the model is one of the
// "search-preview" variants, which reject the temperature setting.
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}
767
// supportsFlexProcessing reports whether the model supports the "flex"
// service tier (the o3, o4-mini, and gpt-5 model families).
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
771
// supportsPriorityProcessing reports whether the model supports the
// "priority" service tier: gpt-4, gpt-5, gpt-5-mini, o3, and o4-mini.
// gpt-5-nano is explicitly excluded, matching the warning emitted in
// prepareParams ("gpt-5-nano is not supported").
func supportsPriorityProcessing(modelID string) bool {
	// Excluded even though it shares the gpt-5 prefix.
	if strings.HasPrefix(modelID, "gpt-5-nano") {
		return false
	}
	// The gpt-5 prefix already covers gpt-5-mini.
	return strings.HasPrefix(modelID, "gpt-4") ||
		strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}
777
778func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
779 for _, tool := range tools {
780 if tool.GetType() == ai.ToolTypeFunction {
781 ft, ok := tool.(ai.FunctionTool)
782 if !ok {
783 continue
784 }
785 openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
786 OfFunction: &openai.ChatCompletionFunctionToolParam{
787 Function: shared.FunctionDefinitionParam{
788 Name: ft.Name,
789 Description: param.NewOpt(ft.Description),
790 Parameters: openai.FunctionParameters(ft.InputSchema),
791 Strict: param.NewOpt(false),
792 },
793 Type: "function",
794 },
795 })
796 continue
797 }
798
799 // TODO: handle provider tool calls
800 warnings = append(warnings, ai.CallWarning{
801 Type: ai.CallWarningTypeUnsupportedTool,
802 Tool: tool,
803 Message: "tool is not supported",
804 })
805 }
806 if toolChoice == nil {
807 return openAiTools, openAiToolChoice, warnings
808 }
809
810 switch *toolChoice {
811 case ai.ToolChoiceAuto:
812 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
813 OfAuto: param.NewOpt("auto"),
814 }
815 case ai.ToolChoiceNone:
816 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
817 OfAuto: param.NewOpt("none"),
818 }
819 default:
820 openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
821 OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
822 Type: "function",
823 Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
824 Name: string(*toolChoice),
825 },
826 },
827 }
828 }
829 return openAiTools, openAiToolChoice, warnings
830}
831
// toPrompt converts the generic prompt messages into OpenAI chat-completion
// message params. Content that cannot be represented (unsupported media
// types, mismatched part types, non-text system content) is dropped and
// reported via the returned warnings rather than failing the call.
func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			// System messages: concatenate all non-empty text parts into a
			// single system message, one line per part.
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// simple user message just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// text content and attachments
			// for now we only support image content later we need to check
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files: embedded as a base64 data URL.
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
								imageURL.Detail = detail.ImageDetail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate default filename based on content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			// NOTE(review): this appends a user message even when every part
			// was dropped and content is empty — confirm that is intended.
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// simple assistant message just text content
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					// NOTE(review): Content is overwritten on each text part,
					// so only the last text part survives when a message has
					// several — confirm whether they should be joined instead.
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			// Tool messages: one OpenAI tool message per tool result part.
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}
1096
1097// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
1098func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
1099 var annotations []openai.ChatCompletionMessageAnnotation
1100
1101 // Parse the raw JSON to extract annotations
1102 var deltaData map[string]any
1103 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
1104 return annotations
1105 }
1106
1107 // Check if annotations exist in the delta
1108 if annotationsData, ok := deltaData["annotations"].([]any); ok {
1109 for _, annotationData := range annotationsData {
1110 if annotationMap, ok := annotationData.(map[string]any); ok {
1111 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
1112 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
1113 annotation := openai.ChatCompletionMessageAnnotation{
1114 Type: "url_citation",
1115 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
1116 URL: urlCitationData["url"].(string),
1117 Title: urlCitationData["title"].(string),
1118 },
1119 }
1120 annotations = append(annotations, annotation)
1121 }
1122 }
1123 }
1124 }
1125 }
1126
1127 return annotations
1128}