package openai

import (
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"maps"
	"strings"

	"github.com/charmbracelet/fantasy/ai"
	xjson "github.com/charmbracelet/x/json"
	"github.com/google/uuid"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/option"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

const (
	Name       = "openai"
	DefaultURL = "https://api.openai.com/v1"
)

type provider struct {
	options options
}

type PrepareCallWithOptions = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)

type Hooks struct {
	PrepareCallWithOptions PrepareCallWithOptions
}
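
// A sketch of a custom hook, assuming the PrepareCallWithOptions signature
// above; the User value is a placeholder. Note that a hook supplied via
// WithHooks replaces the default provider-option handling in prepareParams
// below, it does not extend it.
//
//	hooks := Hooks{
//		PrepareCallWithOptions: func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
//			params.User = param.NewOpt("my-app")
//			return nil, nil
//		},
//	}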

type options struct {
	baseURL      string
	apiKey       string
	organization string
	project      string
	name         string
	hooks        Hooks
	headers      map[string]string
	client       option.HTTPClient
}

type Option = func(*options)

func New(opts ...Option) ai.Provider {
	providerOptions := options{
		headers: map[string]string{},
	}
	for _, o := range opts {
		o(&providerOptions)
	}

	providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
	providerOptions.name = cmp.Or(providerOptions.name, Name)

	if providerOptions.organization != "" {
		providerOptions.headers["OpenAI-Organization"] = providerOptions.organization
	}
	if providerOptions.project != "" {
		providerOptions.headers["OpenAI-Project"] = providerOptions.project
	}

	return &provider{options: providerOptions}
}

func WithBaseURL(baseURL string) Option {
	return func(o *options) {
		o.baseURL = baseURL
	}
}

func WithAPIKey(apiKey string) Option {
	return func(o *options) {
		o.apiKey = apiKey
	}
}

func WithOrganization(organization string) Option {
	return func(o *options) {
		o.organization = organization
	}
}

func WithProject(project string) Option {
	return func(o *options) {
		o.project = project
	}
}

func WithName(name string) Option {
	return func(o *options) {
		o.name = name
	}
}

func WithHeaders(headers map[string]string) Option {
	return func(o *options) {
		maps.Copy(o.headers, headers)
	}
}

func WithHTTPClient(client option.HTTPClient) Option {
	return func(o *options) {
		o.client = client
	}
}

func WithHooks(hooks Hooks) Option {
	return func(o *options) {
		o.hooks = hooks
	}
}
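
// Putting the options together, construction typically looks like this
// (all values are placeholders):
//
//	p := New(
//		WithAPIKey("sk-..."),
//		WithOrganization("org-..."),
//		WithHeaders(map[string]string{"X-Title": "my-app"}),
//	)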

// LanguageModel implements ai.Provider.
func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
	openaiClientOptions := []option.RequestOption{}
	if o.options.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
	}
	if o.options.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
	}

	for key, value := range o.options.headers {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	if o.options.client != nil {
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
	}

	return languageModel{
		modelID:  modelID,
		provider: o.options.name,
		options:  o.options,
		client:   openai.NewClient(openaiClientOptions...),
	}, nil
}
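
// A usage sketch, assuming the ai.Call fields consumed by prepareParams
// below (the model ID and prompt value are placeholders):
//
//	model, err := p.LanguageModel("gpt-4o")
//	if err != nil {
//		return err
//	}
//	resp, err := model.Generate(ctx, ai.Call{Prompt: prompt})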

type languageModel struct {
	provider string
	modelID  string
	client   openai.Client
	options  options
}

// Model implements ai.LanguageModel.
func (o languageModel) Model() string {
	return o.modelID
}

// Provider implements ai.LanguageModel.
func (o languageModel) Provider() string {
	return o.provider
}

func prepareCallWithOptions(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
	if call.ProviderOptions == nil {
		return nil, nil
	}
	var warnings []ai.CallWarning
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
		}
	}

	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// If both are set, TopLogProbs wins and the plain LogProbs flag is dropped.
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}
	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, fmt.Errorf("reasoning effort `%s` not supported", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(model.Model()) {
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "LogitBias",
				Message: "LogitBias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "Logprobs",
				Message: "Logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopLogprobs",
				Message: "TopLogprobs is not supported for reasoning models",
			})
		}
	}

	// Handle service tier validation
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
			})
		}
	}
	return warnings, nil
}
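
// Callers opt into the settings above through per-call provider options
// keyed by this provider's name. A sketch (the concrete map type on ai.Call
// comes from the ai package; User and ReasoningEffort are the fields
// consumed above):
//
//	user := "my-app"
//	effort := ReasoningEffortHigh
//	call.ProviderOptions[Name] = &ProviderOptions{
//		User:            &user,
//		ReasoningEffort: &effort,
//	}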

func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	prepareOptions := prepareCallWithOptions
	if o.options.hooks.PrepareCallWithOptions != nil {
		prepareOptions = o.options.hooks.PrepareCallWithOptions
	}

	optionsWarnings, err := prepareOptions(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}

func (o languageModel) handleError(err error) error {
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		requestDump := apiErr.DumpRequest(true)
		responseDump := apiErr.DumpResponse(true)
		headers := map[string]string{}
		for k, h := range apiErr.Response.Header {
			v := h[len(h)-1]
			headers[strings.ToLower(k)] = v
		}
		return ai.NewAPICallError(
			apiErr.Message,
			apiErr.Request.URL.String(),
			string(requestDump),
			apiErr.StatusCode,
			headers,
			string(responseDump),
			apiErr,
			false,
		)
	}
	return err
}

// Generate implements ai.LanguageModel.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := &ProviderMetadata{}
	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata.Logprobs = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: ai.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}

type toolCall struct {
	id          string
	name        string
	arguments   string
	hasFinished bool
}

// Stream implements ai.LanguageModel.
func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}

	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
	isActiveText := false
	toolCalls := make(map[int64]toolCall)

	// Build provider metadata for streaming
	streamProviderMetadata := &ProviderMetadata{}
	acc := openai.ChatCompletionAccumulator{}
	var usage ai.Usage
	return func(yield func(ai.StreamPart) bool) {
		if len(warnings) > 0 {
			if !yield(ai.StreamPart{
				Type:     ai.StreamPartTypeWarnings,
				Warnings: warnings,
			}) {
				return
			}
		}
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)
			if chunk.Usage.TotalTokens > 0 {
				// we do this here because the acc does not add prompt details
				completionTokenDetails := chunk.Usage.CompletionTokensDetails
				promptTokenDetails := chunk.Usage.PromptTokensDetails
				usage = ai.Usage{
					InputTokens:     chunk.Usage.PromptTokens,
					OutputTokens:    chunk.Usage.CompletionTokens,
					TotalTokens:     chunk.Usage.TotalTokens,
					ReasoningTokens: completionTokenDetails.ReasoningTokens,
					CacheReadTokens: promptTokenDetails.CachedTokens,
				}

				// Add prediction tokens if available
				if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
					if completionTokenDetails.AcceptedPredictionTokens > 0 {
						streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
					}
					if completionTokenDetails.RejectedPredictionTokens > 0 {
						streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
					}
				}
			}
			if len(chunk.Choices) == 0 {
				continue
			}
			for _, choice := range chunk.Choices {
				switch {
				case choice.Delta.Content != "":
					if !isActiveText {
						isActiveText = true
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextStart,
							ID:   "0",
						}) {
							return
						}
					}
					if !yield(ai.StreamPart{
						Type:  ai.StreamPartTypeTextDelta,
						ID:    "0",
						Delta: choice.Delta.Content,
					}) {
						return
					}
				case len(choice.Delta.ToolCalls) > 0:
					if isActiveText {
						isActiveText = false
						if !yield(ai.StreamPart{
							Type: ai.StreamPartTypeTextEnd,
							ID:   "0",
						}) {
							return
						}
					}

					for _, toolCallDelta := range choice.Delta.ToolCalls {
						if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
							if existingToolCall.hasFinished {
								continue
							}
							if toolCallDelta.Function.Arguments != "" {
								existingToolCall.arguments += toolCallDelta.Function.Arguments
							}
							if !yield(ai.StreamPart{
								Type:  ai.StreamPartTypeToolInputDelta,
								ID:    existingToolCall.id,
								Delta: toolCallDelta.Function.Arguments,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = existingToolCall
							if xjson.IsValid(existingToolCall.arguments) {
								if !yield(ai.StreamPart{
									Type: ai.StreamPartTypeToolInputEnd,
									ID:   existingToolCall.id,
								}) {
									return
								}

								if !yield(ai.StreamPart{
									Type:          ai.StreamPartTypeToolCall,
									ID:            existingToolCall.id,
									ToolCallName:  existingToolCall.name,
									ToolCallInput: existingToolCall.arguments,
								}) {
									return
								}
								existingToolCall.hasFinished = true
								toolCalls[toolCallDelta.Index] = existingToolCall
							}
						} else {
							// Does not exist yet
							var err error
							if toolCallDelta.Type != "function" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
							}
							if toolCallDelta.ID == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
							}
							if toolCallDelta.Function.Name == "" {
								err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
							}
							if err != nil {
								// Report the malformed delta itself; stream.Err()
								// is nil while iteration is still succeeding.
								yield(ai.StreamPart{
									Type:  ai.StreamPartTypeError,
									Error: err,
								})
								return
							}

							if !yield(ai.StreamPart{
								Type:         ai.StreamPartTypeToolInputStart,
								ID:           toolCallDelta.ID,
								ToolCallName: toolCallDelta.Function.Name,
							}) {
								return
							}
							toolCalls[toolCallDelta.Index] = toolCall{
								id:        toolCallDelta.ID,
								name:      toolCallDelta.Function.Name,
								arguments: toolCallDelta.Function.Arguments,
							}

							exTc := toolCalls[toolCallDelta.Index]
							if exTc.arguments != "" {
								if !yield(ai.StreamPart{
									Type:  ai.StreamPartTypeToolInputDelta,
									ID:    exTc.id,
									Delta: exTc.arguments,
								}) {
									return
								}
								if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
									if !yield(ai.StreamPart{
										Type: ai.StreamPartTypeToolInputEnd,
										ID:   toolCallDelta.ID,
									}) {
										return
									}

									if !yield(ai.StreamPart{
										Type:          ai.StreamPartTypeToolCall,
										ID:            exTc.id,
										ToolCallName:  exTc.name,
										ToolCallInput: exTc.arguments,
									}) {
										return
									}
									exTc.hasFinished = true
									toolCalls[toolCallDelta.Index] = exTc
								}
							}
							continue
						}
					}
				}
			}

			// Check for annotations in the delta's raw JSON
			for _, choice := range chunk.Choices {
				if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
					for _, annotation := range annotations {
						if annotation.Type == "url_citation" {
							if !yield(ai.StreamPart{
								Type:       ai.StreamPartTypeSource,
								ID:         uuid.NewString(),
								SourceType: ai.SourceTypeURL,
								URL:        annotation.URLCitation.URL,
								Title:      annotation.URLCitation.Title,
							}) {
								return
							}
						}
					}
				}
			}
		}
		err := stream.Err()
		if err == nil || errors.Is(err, io.EOF) {
			// finished
			if isActiveText {
				isActiveText = false
				if !yield(ai.StreamPart{
					Type: ai.StreamPartTypeTextEnd,
					ID:   "0",
				}) {
					return
				}
			}

			// Add logprobs if available
			if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
				streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
			}

			// Handle annotations/citations from accumulated response
			if len(acc.Choices) > 0 {
				for _, annotation := range acc.Choices[0].Message.Annotations {
					if annotation.Type == "url_citation" {
						if !yield(ai.StreamPart{
							Type:       ai.StreamPartTypeSource,
							ID:         acc.ID,
							SourceType: ai.SourceTypeURL,
							URL:        annotation.URLCitation.URL,
							Title:      annotation.URLCitation.Title,
						}) {
							return
						}
					}
				}
			}

			// Guard against a completed stream that produced no choices.
			finishReason := ai.FinishReasonUnknown
			if len(acc.Choices) > 0 {
				finishReason = mapOpenAiFinishReason(acc.Choices[0].FinishReason)
			}
			yield(ai.StreamPart{
				Type:         ai.StreamPartTypeFinish,
				Usage:        usage,
				FinishReason: finishReason,
				ProviderMetadata: ai.ProviderMetadata{
					Name: streamProviderMetadata,
				},
			})
			return
		} else {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: o.handleError(err),
			})
			return
		}
	}, nil
}

func (o *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
	var options ProviderOptions
	if err := ai.ParseOptions(data, &options); err != nil {
		return nil, err
	}
	return &options, nil
}
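
// A sketch of hydrating options from decoded JSON; the "user" key is an
// assumption that mirrors the ProviderOptions.User field (the exact keys
// come from that struct's tags, defined elsewhere in this package):
//
//	opts, err := p.ParseOptions(map[string]any{"user": "my-app"})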

func (o *provider) Name() string {
	return Name
}

func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}

func isReasoningModel(modelID string) bool {
	// gpt-5-chat models are chat-tuned variants that do not accept reasoning
	// settings, so they are excluded from the gpt-5 prefix match.
	return (strings.HasPrefix(modelID, "o") || strings.HasPrefix(modelID, "gpt-5")) &&
		!strings.HasPrefix(modelID, "gpt-5-chat")
}

func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}

func supportsFlexProcessing(modelID string) bool {
	return strings.HasPrefix(modelID, "o3") || strings.HasPrefix(modelID, "o4-mini") || strings.HasPrefix(modelID, "gpt-5")
}

func supportsPriorityProcessing(modelID string) bool {
	// gpt-5-nano does not support priority processing (see the warning in
	// prepareCallWithOptions), so it is carved out of the gpt-5 prefix match,
	// which already covers gpt-5-mini.
	return strings.HasPrefix(modelID, "gpt-4") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-nano")) ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}

func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						Strict:      param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
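
// For reference, a function tool as consumed above. This is a sketch: only
// the ai.FunctionTool fields read by toOpenAiTools are shown, and the schema
// is a minimal placeholder.
//
//	weatherTool := ai.FunctionTool{
//		Name:        "get_weather",
//		Description: "Look up the current weather for a location",
//		InputSchema: map[string]any{"type": "object"},
//	}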

func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
	var messages []openai.ChatCompletionMessageParamUnion
	var warnings []ai.CallWarning
	for _, msg := range prompt {
		switch msg.Role {
		case ai.MessageRoleSystem:
			var systemPromptParts []string
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeText {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt can only have text content",
					})
					continue
				}
				textPart, ok := ai.AsContentType[ai.TextPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "system prompt text part does not have the right type",
					})
					continue
				}
				text := textPart.Text
				if strings.TrimSpace(text) != "" {
					systemPromptParts = append(systemPromptParts, textPart.Text)
				}
			}
			if len(systemPromptParts) == 0 {
				warnings = append(warnings, ai.CallWarning{
					Type:    ai.CallWarningTypeOther,
					Message: "system prompt has no text parts",
				})
				continue
			}
			messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
		case ai.MessageRoleUser:
			// Simple case: the user message is a single text part.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "user message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.UserMessage(textPart.Text))
				continue
			}
			// Text content plus attachments. For now we only support image,
			// audio, and PDF attachments; other media types produce a warning.
			// TODO: add the supported media types to the language model so we
			// can use that to validate the data here.
			var content []openai.ChatCompletionContentPartUnionParam
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message text part does not have the right type",
						})
						continue
					}
					content = append(content, openai.ChatCompletionContentPartUnionParam{
						OfText: &openai.ChatCompletionContentPartTextParam{
							Text: textPart.Text,
						},
					})
				case ai.ContentTypeFile:
					filePart, ok := ai.AsContentType[ai.FilePart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "user message file part does not have the right type",
						})
						continue
					}

					switch {
					case strings.HasPrefix(filePart.MediaType, "image/"):
						// Handle image files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						data := "data:" + filePart.MediaType + ";base64," + base64Encoded
						imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}

						// Check for provider-specific options like image detail
						if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
							if detail, ok := providerOptions.(*ProviderFileOptions); ok {
								imageURL.Detail = detail.ImageDetail
							}
						}

						imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})

					case filePart.MediaType == "audio/wav":
						// Handle WAV audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "wav",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
						// Handle MP3 audio files
						base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
						audioBlock := openai.ChatCompletionContentPartInputAudioParam{
							InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
								Data:   base64Encoded,
								Format: "mp3",
							},
						}
						content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})

					case filePart.MediaType == "application/pdf":
						// Handle PDF files
						dataStr := string(filePart.Data)

						// Check if data looks like a file ID (starts with "file-")
						if strings.HasPrefix(dataStr, "file-") {
							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									FileID: param.NewOpt(dataStr),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						} else {
							// Handle as base64 data
							base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
							data := "data:application/pdf;base64," + base64Encoded

							filename := filePart.Filename
							if filename == "" {
								// Generate default filename based on content index
								filename = fmt.Sprintf("part-%d.pdf", len(content))
							}

							fileBlock := openai.ChatCompletionContentPartFileParam{
								File: openai.ChatCompletionContentPartFileFileParam{
									Filename: param.NewOpt(filename),
									FileData: param.NewOpt(data),
								},
							}
							content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
						}

					default:
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
						})
					}
				}
			}
			messages = append(messages, openai.UserMessage(content))
		case ai.MessageRoleAssistant:
			// Simple case: the assistant message is a single text part.
			if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
				textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "assistant message text part does not have the right type",
					})
					continue
				}
				messages = append(messages, openai.AssistantMessage(textPart.Text))
				continue
			}
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			for _, c := range msg.Content {
				switch c.GetType() {
				case ai.ContentTypeText:
					textPart, ok := ai.AsContentType[ai.TextPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message text part does not have the right type",
						})
						continue
					}
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(textPart.Text),
					}
				case ai.ContentTypeToolCall:
					toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "assistant message tool part does not have the right type",
						})
						continue
					}
					assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
						openai.ChatCompletionMessageToolCallUnionParam{
							OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
								ID:   toolCallPart.ToolCallID,
								Type: "function",
								Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
									Name:      toolCallPart.ToolName,
									Arguments: toolCallPart.Input,
								},
							},
						})
				}
			}
			messages = append(messages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})
		case ai.MessageRoleTool:
			for _, c := range msg.Content {
				if c.GetType() != ai.ContentTypeToolResult {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message can only have tool result content",
					})
					continue
				}

				toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
				if !ok {
					warnings = append(warnings, ai.CallWarning{
						Type:    ai.CallWarningTypeOther,
						Message: "tool message result part does not have the right type",
					})
					continue
				}

				switch toolResultPart.Output.GetType() {
				case ai.ToolResultContentTypeText:
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
				case ai.ToolResultContentTypeError:
					// TODO: check if better handling is needed
					output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
					if !ok {
						warnings = append(warnings, ai.CallWarning{
							Type:    ai.CallWarningTypeOther,
							Message: "tool result output does not have the right type",
						})
						continue
					}
					messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
				}
			}
		}
	}
	return messages, warnings
}

// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
	var annotations []openai.ChatCompletionMessageAnnotation

	// Parse the raw JSON to extract annotations
	var deltaData map[string]any
	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
		return annotations
	}

	// Check if annotations exist in the delta
	if annotationsData, ok := deltaData["annotations"].([]any); ok {
		for _, annotationData := range annotationsData {
			if annotationMap, ok := annotationData.(map[string]any); ok {
				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
						// Use checked assertions so a malformed citation cannot panic.
						url, _ := urlCitationData["url"].(string)
						title, _ := urlCitationData["title"].(string)
						annotation := openai.ChatCompletionMessageAnnotation{
							Type: "url_citation",
							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
								URL:   url,
								Title: title,
							},
						}
						annotations = append(annotations, annotation)
					}
				}
			}
		}
	}

	return annotations
}