1package openai
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "io"
10 "maps"
11 "strings"
12
13 "github.com/charmbracelet/ai"
14 "github.com/charmbracelet/ai/internal/jsonext"
15 "github.com/google/uuid"
16 "github.com/openai/openai-go/v2"
17 "github.com/openai/openai-go/v2/option"
18 "github.com/openai/openai-go/v2/packages/param"
19 "github.com/openai/openai-go/v2/shared"
20)
21
// ReasoningEffort selects how much internal reasoning an OpenAI reasoning
// model spends before answering. It is forwarded to the API as the
// reasoning_effort request parameter.
type ReasoningEffort string

// Reasoning effort levels accepted by prepareParams; any other value is
// rejected with an error when the request parameters are built.
const (
	ReasoningEffortMinimal ReasoningEffort = "minimal"
	ReasoningEffortLow     ReasoningEffort = "low"
	ReasoningEffortMedium  ReasoningEffort = "medium"
	ReasoningEffortHigh    ReasoningEffort = "high"
)
30
31type ProviderOptions struct {
32 LogitBias map[string]int64 `json:"logit_bias"`
33 LogProbs *bool `json:"log_probes"`
34 TopLogProbs *int64 `json:"top_log_probs"`
35 ParallelToolCalls *bool `json:"parallel_tool_calls"`
36 User *string `json:"user"`
37 ReasoningEffort *ReasoningEffort `json:"reasoning_effort"`
38 MaxCompletionTokens *int64 `json:"max_completion_tokens"`
39 TextVerbosity *string `json:"text_verbosity"`
40 Prediction map[string]any `json:"prediction"`
41 Store *bool `json:"store"`
42 Metadata map[string]any `json:"metadata"`
43 PromptCacheKey *string `json:"prompt_cache_key"`
44 SafetyIdentifier *string `json:"safety_identifier"`
45 ServiceTier *string `json:"service_tier"`
46 StructuredOutputs *bool `json:"structured_outputs"`
47}
48
// provider is the ai.Provider implementation returned by New.
type provider struct {
	options options
}

// options holds the provider configuration assembled by New from the
// functional Option values.
type options struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	headers      map[string]string
	client       option.HTTPClient
}

// Option mutates the provider configuration; see the With* constructors.
type Option = func(*options)
64
65func New(opts ...Option) ai.Provider {
66 options := options{
67 headers: map[string]string{},
68 }
69 for _, o := range opts {
70 o(&options)
71 }
72
73 if options.baseURL == "" {
74 options.baseURL = "https://api.openai.com/v1"
75 }
76
77 if options.name == "" {
78 options.name = "openai"
79 }
80
81 if options.organization != "" {
82 options.headers["OpenAi-Organization"] = options.organization
83 }
84
85 if options.project != "" {
86 options.headers["OpenAi-Project"] = options.project
87 }
88
89 return &provider{
90 options: options,
91 }
92}
93
// WithBaseURL overrides the default API endpoint
// ("https://api.openai.com/v1").
func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

// WithAPIKey sets the API key used to authenticate requests.
func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

// WithOrganization sets the organization ID, sent as a request header.
func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}

// WithProject sets the project ID, sent as a request header.
func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}

// WithName overrides the provider name (default "openai") used to build the
// language model's provider string.
func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

// WithHeaders merges additional headers into every outgoing request,
// overwriting any previously configured keys.
func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

// WithHTTPClient sets a custom HTTP client for all API requests.
func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}
135
136// LanguageModel implements ai.Provider.
137func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
138 openaiClientOptions := []option.RequestOption{}
139 if o.options.apiKey != "" {
140 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
141 }
142 if o.options.baseURL != "" {
143 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
144 }
145
146 for key, value := range o.options.headers {
147 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
148 }
149
150 if o.options.client != nil {
151 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
152 }
153
154 return languageModel{
155 modelID: modelID,
156 provider: fmt.Sprintf("%s.chat", o.options.name),
157 options: o.options,
158 client: openai.NewClient(openaiClientOptions...),
159 }, nil
160}
161
// languageModel is the chat-completions implementation of ai.LanguageModel.
type languageModel struct {
	provider string
	modelID  string
	client   openai.Client
	options  options
}

// Model implements ai.LanguageModel. It returns the model identifier.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel. It returns the provider string,
// e.g. "openai.chat".
func (o languageModel) Provider() string {
	return o.provider
}
178
179func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
180 params := &openai.ChatCompletionNewParams{}
181 messages, warnings := toPrompt(call.Prompt)
182 providerOptions := &ProviderOptions{}
183 if v, ok := call.ProviderOptions["openai"]; ok {
184 err := ai.ParseOptions(v, providerOptions)
185 if err != nil {
186 return nil, nil, err
187 }
188 }
189 if call.TopK != nil {
190 warnings = append(warnings, ai.CallWarning{
191 Type: ai.CallWarningTypeUnsupportedSetting,
192 Setting: "top_k",
193 })
194 }
195 params.Messages = messages
196 params.Model = o.modelID
197 if providerOptions.LogitBias != nil {
198 params.LogitBias = providerOptions.LogitBias
199 }
200 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
201 providerOptions.LogProbs = nil
202 }
203 if providerOptions.LogProbs != nil {
204 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
205 }
206 if providerOptions.TopLogProbs != nil {
207 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
208 }
209 if providerOptions.User != nil {
210 params.User = param.NewOpt(*providerOptions.User)
211 }
212 if providerOptions.ParallelToolCalls != nil {
213 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
214 }
215
216 if call.MaxOutputTokens != nil {
217 params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
218 }
219 if call.Temperature != nil {
220 params.Temperature = param.NewOpt(*call.Temperature)
221 }
222 if call.TopP != nil {
223 params.TopP = param.NewOpt(*call.TopP)
224 }
225 if call.FrequencyPenalty != nil {
226 params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
227 }
228 if call.PresencePenalty != nil {
229 params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
230 }
231
232 if providerOptions.MaxCompletionTokens != nil {
233 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
234 }
235
236 if providerOptions.TextVerbosity != nil {
237 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
238 }
239 if providerOptions.Prediction != nil {
240 // Convert map[string]any to ChatCompletionPredictionContentParam
241 if content, ok := providerOptions.Prediction["content"]; ok {
242 if contentStr, ok := content.(string); ok {
243 params.Prediction = openai.ChatCompletionPredictionContentParam{
244 Content: openai.ChatCompletionPredictionContentContentUnionParam{
245 OfString: param.NewOpt(contentStr),
246 },
247 }
248 }
249 }
250 }
251 if providerOptions.Store != nil {
252 params.Store = param.NewOpt(*providerOptions.Store)
253 }
254 if providerOptions.Metadata != nil {
255 // Convert map[string]any to map[string]string
256 metadata := make(map[string]string)
257 for k, v := range providerOptions.Metadata {
258 if str, ok := v.(string); ok {
259 metadata[k] = str
260 }
261 }
262 params.Metadata = metadata
263 }
264 if providerOptions.PromptCacheKey != nil {
265 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
266 }
267 if providerOptions.SafetyIdentifier != nil {
268 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
269 }
270 if providerOptions.ServiceTier != nil {
271 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
272 }
273
274 if providerOptions.ReasoningEffort != nil {
275 switch *providerOptions.ReasoningEffort {
276 case ReasoningEffortMinimal:
277 params.ReasoningEffort = shared.ReasoningEffortMinimal
278 case ReasoningEffortLow:
279 params.ReasoningEffort = shared.ReasoningEffortLow
280 case ReasoningEffortMedium:
281 params.ReasoningEffort = shared.ReasoningEffortMedium
282 case ReasoningEffortHigh:
283 params.ReasoningEffort = shared.ReasoningEffortHigh
284 default:
285 return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
286 }
287 }
288
289 if isReasoningModel(o.modelID) {
290 // remove unsupported settings for reasoning models
291 // see https://platform.openai.com/docs/guides/reasoning#limitations
292 if call.Temperature != nil {
293 params.Temperature = param.Opt[float64]{}
294 warnings = append(warnings, ai.CallWarning{
295 Type: ai.CallWarningTypeUnsupportedSetting,
296 Setting: "temperature",
297 Details: "temperature is not supported for reasoning models",
298 })
299 }
300 if call.TopP != nil {
301 params.TopP = param.Opt[float64]{}
302 warnings = append(warnings, ai.CallWarning{
303 Type: ai.CallWarningTypeUnsupportedSetting,
304 Setting: "TopP",
305 Details: "TopP is not supported for reasoning models",
306 })
307 }
308 if call.FrequencyPenalty != nil {
309 params.FrequencyPenalty = param.Opt[float64]{}
310 warnings = append(warnings, ai.CallWarning{
311 Type: ai.CallWarningTypeUnsupportedSetting,
312 Setting: "FrequencyPenalty",
313 Details: "FrequencyPenalty is not supported for reasoning models",
314 })
315 }
316 if call.PresencePenalty != nil {
317 params.PresencePenalty = param.Opt[float64]{}
318 warnings = append(warnings, ai.CallWarning{
319 Type: ai.CallWarningTypeUnsupportedSetting,
320 Setting: "PresencePenalty",
321 Details: "PresencePenalty is not supported for reasoning models",
322 })
323 }
324 if providerOptions.LogitBias != nil {
325 params.LogitBias = nil
326 warnings = append(warnings, ai.CallWarning{
327 Type: ai.CallWarningTypeUnsupportedSetting,
328 Setting: "LogitBias",
329 Message: "LogitBias is not supported for reasoning models",
330 })
331 }
332 if providerOptions.LogProbs != nil {
333 params.Logprobs = param.Opt[bool]{}
334 warnings = append(warnings, ai.CallWarning{
335 Type: ai.CallWarningTypeUnsupportedSetting,
336 Setting: "Logprobs",
337 Message: "Logprobs is not supported for reasoning models",
338 })
339 }
340 if providerOptions.TopLogProbs != nil {
341 params.TopLogprobs = param.Opt[int64]{}
342 warnings = append(warnings, ai.CallWarning{
343 Type: ai.CallWarningTypeUnsupportedSetting,
344 Setting: "TopLogprobs",
345 Message: "TopLogprobs is not supported for reasoning models",
346 })
347 }
348
349 // reasoning models use max_completion_tokens instead of max_tokens
350 if call.MaxOutputTokens != nil {
351 if providerOptions.MaxCompletionTokens == nil {
352 params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
353 }
354 params.MaxTokens = param.Opt[int64]{}
355 }
356 }
357
358 // Handle search preview models
359 if isSearchPreviewModel(o.modelID) {
360 if call.Temperature != nil {
361 params.Temperature = param.Opt[float64]{}
362 warnings = append(warnings, ai.CallWarning{
363 Type: ai.CallWarningTypeUnsupportedSetting,
364 Setting: "temperature",
365 Details: "temperature is not supported for the search preview models and has been removed.",
366 })
367 }
368 }
369
370 // Handle service tier validation
371 if providerOptions.ServiceTier != nil {
372 serviceTier := *providerOptions.ServiceTier
373 if serviceTier == "flex" && !supportsFlexProcessing(o.modelID) {
374 params.ServiceTier = ""
375 warnings = append(warnings, ai.CallWarning{
376 Type: ai.CallWarningTypeUnsupportedSetting,
377 Setting: "ServiceTier",
378 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
379 })
380 } else if serviceTier == "priority" && !supportsPriorityProcessing(o.modelID) {
381 params.ServiceTier = ""
382 warnings = append(warnings, ai.CallWarning{
383 Type: ai.CallWarningTypeUnsupportedSetting,
384 Setting: "ServiceTier",
385 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
386 })
387 }
388 }
389
390 if len(call.Tools) > 0 {
391 tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
392 params.Tools = tools
393 if toolChoice != nil {
394 params.ToolChoice = *toolChoice
395 }
396 warnings = append(warnings, toolWarnings...)
397 }
398 return params, warnings, nil
399}
400
401func (o languageModel) handleError(err error) error {
402 var apiErr *openai.Error
403 if errors.As(err, &apiErr) {
404 requestDump := apiErr.DumpRequest(true)
405 responseDump := apiErr.DumpResponse(true)
406 headers := map[string]string{}
407 for k, h := range apiErr.Response.Header {
408 v := h[len(h)-1]
409 headers[strings.ToLower(k)] = v
410 }
411 return ai.NewAPICallError(
412 apiErr.Message,
413 apiErr.Request.URL.String(),
414 string(requestDump),
415 apiErr.StatusCode,
416 headers,
417 string(responseDump),
418 apiErr,
419 false,
420 )
421 }
422 return err
423}
424
// Generate implements ai.LanguageModel. It performs a single non-streaming
// chat completion request, maps the first choice into ai.Content parts
// (text, tool calls, URL citations), and reports token usage plus
// OpenAI-specific metadata (logprobs, prediction-token counts).
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	// Only the first choice is consumed; n > 1 is not requested.
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	// A tool call without an ID gets a generated UUID so downstream
	// correlation of call and result still works.
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations; only url_citation annotations are mapped.
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := ai.ProviderMetadata{
		"openai": make(map[string]any),
	}

	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata["openai"]["logprobs"] = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason:     mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: providerMetadata,
		Warnings:         warnings,
	}, nil
}
509
// toolCall accumulates the state of a streamed tool call across argument
// delta chunks, keyed by the delta's index in the Stream loop.
type toolCall struct {
	id        string
	name      string
	arguments string
	// hasFinished is set once the complete tool call has been emitted;
	// further deltas for the same index are ignored.
	hasFinished bool
}
516
517// Stream implements ai.LanguageModel.
518func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
519 params, warnings, err := o.prepareParams(call)
520 if err != nil {
521 return nil, err
522 }
523
524 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
525 IncludeUsage: openai.Bool(true),
526 }
527
528 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
529 isActiveText := false
530 toolCalls := make(map[int64]toolCall)
531
532 // Build provider metadata for streaming
533 streamProviderMetadata := ai.ProviderMetadata{
534 "openai": make(map[string]any),
535 }
536
537 acc := openai.ChatCompletionAccumulator{}
538 var usage ai.Usage
539 return func(yield func(ai.StreamPart) bool) {
540 if len(warnings) > 0 {
541 if !yield(ai.StreamPart{
542 Type: ai.StreamPartTypeWarnings,
543 Warnings: warnings,
544 }) {
545 return
546 }
547 }
548 for stream.Next() {
549 chunk := stream.Current()
550 acc.AddChunk(chunk)
551 if chunk.Usage.TotalTokens > 0 {
552 // we do this here because the acc does not add prompt details
553 completionTokenDetails := chunk.Usage.CompletionTokensDetails
554 promptTokenDetails := chunk.Usage.PromptTokensDetails
555 usage = ai.Usage{
556 InputTokens: chunk.Usage.PromptTokens,
557 OutputTokens: chunk.Usage.CompletionTokens,
558 TotalTokens: chunk.Usage.TotalTokens,
559 ReasoningTokens: completionTokenDetails.ReasoningTokens,
560 CacheReadTokens: promptTokenDetails.CachedTokens,
561 }
562
563 // Add prediction tokens if available
564 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
565 if completionTokenDetails.AcceptedPredictionTokens > 0 {
566 streamProviderMetadata["openai"]["acceptedPredictionTokens"] = completionTokenDetails.AcceptedPredictionTokens
567 }
568 if completionTokenDetails.RejectedPredictionTokens > 0 {
569 streamProviderMetadata["openai"]["rejectedPredictionTokens"] = completionTokenDetails.RejectedPredictionTokens
570 }
571 }
572 }
573 if len(chunk.Choices) == 0 {
574 continue
575 }
576 for _, choice := range chunk.Choices {
577 switch {
578 case choice.Delta.Content != "":
579 if !isActiveText {
580 isActiveText = true
581 if !yield(ai.StreamPart{
582 Type: ai.StreamPartTypeTextStart,
583 ID: "0",
584 }) {
585 return
586 }
587 }
588 if !yield(ai.StreamPart{
589 Type: ai.StreamPartTypeTextDelta,
590 ID: "0",
591 Delta: choice.Delta.Content,
592 }) {
593 return
594 }
595 case len(choice.Delta.ToolCalls) > 0:
596 if isActiveText {
597 isActiveText = false
598 if !yield(ai.StreamPart{
599 Type: ai.StreamPartTypeTextEnd,
600 ID: "0",
601 }) {
602 return
603 }
604 }
605
606 for _, toolCallDelta := range choice.Delta.ToolCalls {
607 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
608 if existingToolCall.hasFinished {
609 continue
610 }
611 if toolCallDelta.Function.Arguments != "" {
612 existingToolCall.arguments += toolCallDelta.Function.Arguments
613 }
614 if !yield(ai.StreamPart{
615 Type: ai.StreamPartTypeToolInputDelta,
616 ID: existingToolCall.id,
617 Delta: toolCallDelta.Function.Arguments,
618 }) {
619 return
620 }
621 toolCalls[toolCallDelta.Index] = existingToolCall
622 if jsonext.IsValidJSON(existingToolCall.arguments) {
623 if !yield(ai.StreamPart{
624 Type: ai.StreamPartTypeToolInputEnd,
625 ID: existingToolCall.id,
626 }) {
627 return
628 }
629
630 if !yield(ai.StreamPart{
631 Type: ai.StreamPartTypeToolCall,
632 ID: existingToolCall.id,
633 ToolCallName: existingToolCall.name,
634 ToolCallInput: existingToolCall.arguments,
635 }) {
636 return
637 }
638 existingToolCall.hasFinished = true
639 toolCalls[toolCallDelta.Index] = existingToolCall
640 }
641 } else {
642 // Does not exist
643 var err error
644 if toolCallDelta.Type != "function" {
645 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
646 }
647 if toolCallDelta.ID == "" {
648 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
649 }
650 if toolCallDelta.Function.Name == "" {
651 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
652 }
653 if err != nil {
654 yield(ai.StreamPart{
655 Type: ai.StreamPartTypeError,
656 Error: o.handleError(stream.Err()),
657 })
658 return
659 }
660
661 if !yield(ai.StreamPart{
662 Type: ai.StreamPartTypeToolInputStart,
663 ID: toolCallDelta.ID,
664 ToolCallName: toolCallDelta.Function.Name,
665 }) {
666 return
667 }
668 toolCalls[toolCallDelta.Index] = toolCall{
669 id: toolCallDelta.ID,
670 name: toolCallDelta.Function.Name,
671 arguments: toolCallDelta.Function.Arguments,
672 }
673
674 exTc := toolCalls[toolCallDelta.Index]
675 if exTc.arguments != "" {
676 if !yield(ai.StreamPart{
677 Type: ai.StreamPartTypeToolInputDelta,
678 ID: exTc.id,
679 Delta: exTc.arguments,
680 }) {
681 return
682 }
683 if jsonext.IsValidJSON(toolCalls[toolCallDelta.Index].arguments) {
684 if !yield(ai.StreamPart{
685 Type: ai.StreamPartTypeToolInputEnd,
686 ID: toolCallDelta.ID,
687 }) {
688 return
689 }
690
691 if !yield(ai.StreamPart{
692 Type: ai.StreamPartTypeToolCall,
693 ID: exTc.id,
694 ToolCallName: exTc.name,
695 ToolCallInput: exTc.arguments,
696 }) {
697 return
698 }
699 exTc.hasFinished = true
700 toolCalls[toolCallDelta.Index] = exTc
701 }
702 }
703 continue
704 }
705 }
706 }
707 }
708
709 // Check for annotations in the delta's raw JSON
710 for _, choice := range chunk.Choices {
711 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
712 for _, annotation := range annotations {
713 if annotation.Type == "url_citation" {
714 if !yield(ai.StreamPart{
715 Type: ai.StreamPartTypeSource,
716 ID: uuid.NewString(),
717 SourceType: ai.SourceTypeURL,
718 URL: annotation.URLCitation.URL,
719 Title: annotation.URLCitation.Title,
720 }) {
721 return
722 }
723 }
724 }
725 }
726 }
727 }
728 err := stream.Err()
729 if err == nil || errors.Is(err, io.EOF) {
730 // finished
731 if isActiveText {
732 isActiveText = false
733 if !yield(ai.StreamPart{
734 Type: ai.StreamPartTypeTextEnd,
735 ID: "0",
736 }) {
737 return
738 }
739 }
740
741 // Add logprobs if available
742 if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
743 streamProviderMetadata["openai"]["logprobs"] = acc.Choices[0].Logprobs.Content
744 }
745
746 // Handle annotations/citations from accumulated response
747 if len(acc.Choices) > 0 {
748 for _, annotation := range acc.Choices[0].Message.Annotations {
749 if annotation.Type == "url_citation" {
750 if !yield(ai.StreamPart{
751 Type: ai.StreamPartTypeSource,
752 ID: acc.ID,
753 SourceType: ai.SourceTypeURL,
754 URL: annotation.URLCitation.URL,
755 Title: annotation.URLCitation.Title,
756 }) {
757 return
758 }
759 }
760 }
761 }
762
763 finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason)
764 yield(ai.StreamPart{
765 Type: ai.StreamPartTypeFinish,
766 Usage: usage,
767 FinishReason: finishReason,
768 ProviderMetadata: streamProviderMetadata,
769 })
770 return
771 } else {
772 yield(ai.StreamPart{
773 Type: ai.StreamPartTypeError,
774 Error: o.handleError(err),
775 })
776 return
777 }
778 }, nil
779}
780
// mapOpenAiFinishReason translates OpenAI's finish_reason string into the
// SDK's ai.FinishReason enum. Unknown values map to FinishReasonUnknown.
func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	// "function_call" is the legacy spelling of "tool_calls".
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}
795
// isReasoningModel reports whether modelID names a reasoning model (the
// o-series and gpt-5 family), which use max_completion_tokens and reject
// sampling settings like temperature.
//
// NOTE(review): the "o" prefix matches any model starting with that letter.
// The original also listed a "gpt-5-chat" prefix, which the "gpt-5" prefix
// already covers (dead code, removed here) — confirm whether gpt-5-chat was
// instead meant to be excluded.
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")
}
799
// isSearchPreviewModel reports whether id names one of the "search-preview"
// chat model variants, which reject the temperature setting.
func isSearchPreviewModel(id string) bool {
	return strings.Contains(id, "search-preview")
}
803
// supportsFlexProcessing reports whether the model accepts the "flex"
// service tier (o3, o4-mini, and gpt-5 families).
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
807
// supportsPriorityProcessing reports whether the model accepts the
// "priority" service tier.
//
// BUGFIX: the previous version returned true for gpt-5-nano (matched by the
// "gpt-5" prefix), contradicting the warning emitted in prepareParams, which
// states that gpt-5-nano is not supported. The redundant "gpt-5-mini" clause
// (already covered by "gpt-5") is also removed.
func supportsPriorityProcessing(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-nano") {
		return false
	}
	return strings.HasPrefix(modelID, "gpt-4") ||
		strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}
813
// toOpenAiTools converts the SDK's tool definitions and tool-choice setting
// into the OpenAI request representation. Non-function tools are skipped with
// a warning. The returned tool choice is nil when the caller specified none.
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						// Strict JSON-schema mode is always off here.
						Strict: param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		// NOTE(review): any other choice value is treated as the name of a
		// specific tool to force. If the ai package defines a "required"
		// choice, it would be misrouted here as a tool name — confirm.
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
867
868func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
869 var messages []openai.ChatCompletionMessageParamUnion
870 var warnings []ai.CallWarning
871 for _, msg := range prompt {
872 switch msg.Role {
873 case ai.MessageRoleSystem:
874 var systemPromptParts []string
875 for _, c := range msg.Content {
876 if c.GetType() != ai.ContentTypeText {
877 warnings = append(warnings, ai.CallWarning{
878 Type: ai.CallWarningTypeOther,
879 Message: "system prompt can only have text content",
880 })
881 continue
882 }
883 textPart, ok := ai.AsContentType[ai.TextPart](c)
884 if !ok {
885 warnings = append(warnings, ai.CallWarning{
886 Type: ai.CallWarningTypeOther,
887 Message: "system prompt text part does not have the right type",
888 })
889 continue
890 }
891 text := textPart.Text
892 if strings.TrimSpace(text) != "" {
893 systemPromptParts = append(systemPromptParts, textPart.Text)
894 }
895 }
896 if len(systemPromptParts) == 0 {
897 warnings = append(warnings, ai.CallWarning{
898 Type: ai.CallWarningTypeOther,
899 Message: "system prompt has no text parts",
900 })
901 continue
902 }
903 messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
904 case ai.MessageRoleUser:
905 // simple user message just text content
906 if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
907 textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
908 if !ok {
909 warnings = append(warnings, ai.CallWarning{
910 Type: ai.CallWarningTypeOther,
911 Message: "user message text part does not have the right type",
912 })
913 continue
914 }
915 messages = append(messages, openai.UserMessage(textPart.Text))
916 continue
917 }
918 // text content and attachments
919 // for now we only support image content later we need to check
920 // TODO: add the supported media types to the language model so we
921 // can use that to validate the data here.
922 var content []openai.ChatCompletionContentPartUnionParam
923 for _, c := range msg.Content {
924 switch c.GetType() {
925 case ai.ContentTypeText:
926 textPart, ok := ai.AsContentType[ai.TextPart](c)
927 if !ok {
928 warnings = append(warnings, ai.CallWarning{
929 Type: ai.CallWarningTypeOther,
930 Message: "user message text part does not have the right type",
931 })
932 continue
933 }
934 content = append(content, openai.ChatCompletionContentPartUnionParam{
935 OfText: &openai.ChatCompletionContentPartTextParam{
936 Text: textPart.Text,
937 },
938 })
939 case ai.ContentTypeFile:
940 filePart, ok := ai.AsContentType[ai.FilePart](c)
941 if !ok {
942 warnings = append(warnings, ai.CallWarning{
943 Type: ai.CallWarningTypeOther,
944 Message: "user message file part does not have the right type",
945 })
946 continue
947 }
948
949 switch {
950 case strings.HasPrefix(filePart.MediaType, "image/"):
951 // Handle image files
952 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
953 data := "data:" + filePart.MediaType + ";base64," + base64Encoded
954 imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
955
956 // Check for provider-specific options like image detail
957 if providerOptions, ok := filePart.ProviderOptions["openai"]; ok {
958 if detail, ok := providerOptions["imageDetail"].(string); ok {
959 imageURL.Detail = detail
960 }
961 }
962
963 imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
964 content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
965
966 case filePart.MediaType == "audio/wav":
967 // Handle WAV audio files
968 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
969 audioBlock := openai.ChatCompletionContentPartInputAudioParam{
970 InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
971 Data: base64Encoded,
972 Format: "wav",
973 },
974 }
975 content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
976
977 case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
978 // Handle MP3 audio files
979 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
980 audioBlock := openai.ChatCompletionContentPartInputAudioParam{
981 InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
982 Data: base64Encoded,
983 Format: "mp3",
984 },
985 }
986 content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
987
988 case filePart.MediaType == "application/pdf":
989 // Handle PDF files
990 dataStr := string(filePart.Data)
991
992 // Check if data looks like a file ID (starts with "file-")
993 if strings.HasPrefix(dataStr, "file-") {
994 fileBlock := openai.ChatCompletionContentPartFileParam{
995 File: openai.ChatCompletionContentPartFileFileParam{
996 FileID: param.NewOpt(dataStr),
997 },
998 }
999 content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1000 } else {
1001 // Handle as base64 data
1002 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
1003 data := "data:application/pdf;base64," + base64Encoded
1004
1005 filename := filePart.Filename
1006 if filename == "" {
1007 // Generate default filename based on content index
1008 filename = fmt.Sprintf("part-%d.pdf", len(content))
1009 }
1010
1011 fileBlock := openai.ChatCompletionContentPartFileParam{
1012 File: openai.ChatCompletionContentPartFileFileParam{
1013 Filename: param.NewOpt(filename),
1014 FileData: param.NewOpt(data),
1015 },
1016 }
1017 content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1018 }
1019
1020 default:
1021 warnings = append(warnings, ai.CallWarning{
1022 Type: ai.CallWarningTypeOther,
1023 Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
1024 })
1025 }
1026 }
1027 }
1028 messages = append(messages, openai.UserMessage(content))
1029 case ai.MessageRoleAssistant:
1030 // simple assistant message just text content
1031 if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
1032 textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
1033 if !ok {
1034 warnings = append(warnings, ai.CallWarning{
1035 Type: ai.CallWarningTypeOther,
1036 Message: "assistant message text part does not have the right type",
1037 })
1038 continue
1039 }
1040 messages = append(messages, openai.AssistantMessage(textPart.Text))
1041 continue
1042 }
1043 assistantMsg := openai.ChatCompletionAssistantMessageParam{
1044 Role: "assistant",
1045 }
1046 for _, c := range msg.Content {
1047 switch c.GetType() {
1048 case ai.ContentTypeText:
1049 textPart, ok := ai.AsContentType[ai.TextPart](c)
1050 if !ok {
1051 warnings = append(warnings, ai.CallWarning{
1052 Type: ai.CallWarningTypeOther,
1053 Message: "assistant message text part does not have the right type",
1054 })
1055 continue
1056 }
1057 assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
1058 OfString: param.NewOpt(textPart.Text),
1059 }
1060 case ai.ContentTypeToolCall:
1061 toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
1062 if !ok {
1063 warnings = append(warnings, ai.CallWarning{
1064 Type: ai.CallWarningTypeOther,
1065 Message: "assistant message tool part does not have the right type",
1066 })
1067 continue
1068 }
1069 assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
1070 openai.ChatCompletionMessageToolCallUnionParam{
1071 OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
1072 ID: toolCallPart.ToolCallID,
1073 Type: "function",
1074 Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
1075 Name: toolCallPart.ToolName,
1076 Arguments: toolCallPart.Input,
1077 },
1078 },
1079 })
1080 }
1081 }
1082 messages = append(messages, openai.ChatCompletionMessageParamUnion{
1083 OfAssistant: &assistantMsg,
1084 })
1085 case ai.MessageRoleTool:
1086 for _, c := range msg.Content {
1087 if c.GetType() != ai.ContentTypeToolResult {
1088 warnings = append(warnings, ai.CallWarning{
1089 Type: ai.CallWarningTypeOther,
1090 Message: "tool message can only have tool result content",
1091 })
1092 continue
1093 }
1094
1095 toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
1096 if !ok {
1097 warnings = append(warnings, ai.CallWarning{
1098 Type: ai.CallWarningTypeOther,
1099 Message: "tool message result part does not have the right type",
1100 })
1101 continue
1102 }
1103
1104 switch toolResultPart.Output.GetType() {
1105 case ai.ToolResultContentTypeText:
1106 output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
1107 if !ok {
1108 warnings = append(warnings, ai.CallWarning{
1109 Type: ai.CallWarningTypeOther,
1110 Message: "tool result output does not have the right type",
1111 })
1112 continue
1113 }
1114 messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
1115 case ai.ToolResultContentTypeError:
1116 // TODO: check if better handling is needed
1117 output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
1118 if !ok {
1119 warnings = append(warnings, ai.CallWarning{
1120 Type: ai.CallWarningTypeOther,
1121 Message: "tool result output does not have the right type",
1122 })
1123 continue
1124 }
1125 messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
1126 }
1127 }
1128 }
1129 }
1130 return messages, warnings
1131}
1132
1133// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
1134func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
1135 var annotations []openai.ChatCompletionMessageAnnotation
1136
1137 // Parse the raw JSON to extract annotations
1138 var deltaData map[string]any
1139 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
1140 return annotations
1141 }
1142
1143 // Check if annotations exist in the delta
1144 if annotationsData, ok := deltaData["annotations"].([]any); ok {
1145 for _, annotationData := range annotationsData {
1146 if annotationMap, ok := annotationData.(map[string]any); ok {
1147 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
1148 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
1149 annotation := openai.ChatCompletionMessageAnnotation{
1150 Type: "url_citation",
1151 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
1152 URL: urlCitationData["url"].(string),
1153 Title: urlCitationData["title"].(string),
1154 },
1155 }
1156 annotations = append(annotations, annotation)
1157 }
1158 }
1159 }
1160 }
1161 }
1162
1163 return annotations
1164}