1package openai
2
3import (
4 "cmp"
5 "context"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "maps"
12 "strings"
13
14 "github.com/charmbracelet/fantasy/ai"
15 xjson "github.com/charmbracelet/x/json"
16 "github.com/google/uuid"
17 "github.com/openai/openai-go/v2"
18 "github.com/openai/openai-go/v2/option"
19 "github.com/openai/openai-go/v2/packages/param"
20 "github.com/openai/openai-go/v2/shared"
21)
22
const (
	// Name is the canonical identifier for this provider.
	Name = "openai"
	// DefaultURL is the OpenAI API endpoint used when no base URL is set.
	DefaultURL = "https://api.openai.com/v1"
)
27
// provider is the OpenAI implementation of ai.Provider.
type provider struct {
	options options
}

// PrepareCallWithOptions is the signature of the hook that applies
// provider-specific options to an outgoing chat completion request,
// returning any warnings produced while doing so.
type PrepareCallWithOptions = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)

// Hooks bundles optional callbacks that customize provider behavior.
type Hooks struct {
	// PrepareCallWithOptions, when set, replaces the default
	// provider-option handling during call preparation.
	PrepareCallWithOptions PrepareCallWithOptions
}

// options holds the configuration collected by the Option functions.
type options struct {
	baseURL      string            // API endpoint; defaults to DefaultURL
	apiKey       string            // API key sent with each request
	organization string            // sent via the organization header when non-empty
	project      string            // sent via the project header when non-empty
	name         string            // provider name; defaults to Name
	hooks        Hooks             // optional behavior overrides
	headers      map[string]string // extra headers added to every request
	client       option.HTTPClient // custom HTTP client, if any
}

// Option mutates provider options during construction via New.
type Option = func(*options)
50
51func New(opts ...Option) ai.Provider {
52 providerOptions := options{
53 headers: map[string]string{},
54 }
55 for _, o := range opts {
56 o(&providerOptions)
57 }
58
59 providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL)
60 providerOptions.name = cmp.Or(providerOptions.name, Name)
61
62 if providerOptions.organization != "" {
63 providerOptions.headers["OpenAi-Organization"] = providerOptions.organization
64 }
65 if providerOptions.project != "" {
66 providerOptions.headers["OpenAi-Project"] = providerOptions.project
67 }
68
69 return &provider{options: providerOptions}
70}
71
72func WithBaseURL(baseURL string) Option {
73 return func(o *options) {
74 o.baseURL = baseURL
75 }
76}
77
78func WithAPIKey(apiKey string) Option {
79 return func(o *options) {
80 o.apiKey = apiKey
81 }
82}
83
84func WithOrganization(organization string) Option {
85 return func(o *options) {
86 o.organization = organization
87 }
88}
89
90func WithProject(project string) Option {
91 return func(o *options) {
92 o.project = project
93 }
94}
95
96func WithName(name string) Option {
97 return func(o *options) {
98 o.name = name
99 }
100}
101
102func WithHeaders(headers map[string]string) Option {
103 return func(o *options) {
104 maps.Copy(o.headers, headers)
105 }
106}
107
108func WithHTTPClient(client option.HTTPClient) Option {
109 return func(o *options) {
110 o.client = client
111 }
112}
113
114func WithHooks(hooks Hooks) Option {
115 return func(o *options) {
116 o.hooks = hooks
117 }
118}
119
120// LanguageModel implements ai.Provider.
121func (o *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
122 openaiClientOptions := []option.RequestOption{}
123 if o.options.apiKey != "" {
124 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(o.options.apiKey))
125 }
126 if o.options.baseURL != "" {
127 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(o.options.baseURL))
128 }
129
130 for key, value := range o.options.headers {
131 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
132 }
133
134 if o.options.client != nil {
135 openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(o.options.client))
136 }
137
138 return languageModel{
139 modelID: modelID,
140 provider: o.options.name,
141 options: o.options,
142 client: openai.NewClient(openaiClientOptions...),
143 }, nil
144}
145
// languageModel is a chat-completions-backed ai.LanguageModel.
type languageModel struct {
	provider string        // provider name reported to callers
	modelID  string        // OpenAI model identifier
	client   openai.Client // configured OpenAI SDK client
	options  options       // provider-level configuration
}
152
// Model implements ai.LanguageModel. It returns the OpenAI model identifier.
func (o languageModel) Model() string {
	return o.modelID
}
157
// Provider implements ai.LanguageModel. It returns the configured provider
// name (defaults to Name).
func (o languageModel) Provider() string {
	return o.provider
}
162
163func prepareCallWithOptions(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
164 if call.ProviderOptions == nil {
165 return nil, nil
166 }
167 var warnings []ai.CallWarning
168 providerOptions := &ProviderOptions{}
169 if v, ok := call.ProviderOptions[Name]; ok {
170 providerOptions, ok = v.(*ProviderOptions)
171 if !ok {
172 return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
173 }
174 }
175
176 if providerOptions.LogitBias != nil {
177 params.LogitBias = providerOptions.LogitBias
178 }
179 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
180 providerOptions.LogProbs = nil
181 }
182 if providerOptions.LogProbs != nil {
183 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
184 }
185 if providerOptions.TopLogProbs != nil {
186 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
187 }
188 if providerOptions.User != nil {
189 params.User = param.NewOpt(*providerOptions.User)
190 }
191 if providerOptions.ParallelToolCalls != nil {
192 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
193 }
194 if providerOptions.MaxCompletionTokens != nil {
195 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
196 }
197
198 if providerOptions.TextVerbosity != nil {
199 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
200 }
201 if providerOptions.Prediction != nil {
202 // Convert map[string]any to ChatCompletionPredictionContentParam
203 if content, ok := providerOptions.Prediction["content"]; ok {
204 if contentStr, ok := content.(string); ok {
205 params.Prediction = openai.ChatCompletionPredictionContentParam{
206 Content: openai.ChatCompletionPredictionContentContentUnionParam{
207 OfString: param.NewOpt(contentStr),
208 },
209 }
210 }
211 }
212 }
213 if providerOptions.Store != nil {
214 params.Store = param.NewOpt(*providerOptions.Store)
215 }
216 if providerOptions.Metadata != nil {
217 // Convert map[string]any to map[string]string
218 metadata := make(map[string]string)
219 for k, v := range providerOptions.Metadata {
220 if str, ok := v.(string); ok {
221 metadata[k] = str
222 }
223 }
224 params.Metadata = metadata
225 }
226 if providerOptions.PromptCacheKey != nil {
227 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
228 }
229 if providerOptions.SafetyIdentifier != nil {
230 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
231 }
232 if providerOptions.ServiceTier != nil {
233 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
234 }
235
236 if providerOptions.ReasoningEffort != nil {
237 switch *providerOptions.ReasoningEffort {
238 case ReasoningEffortMinimal:
239 params.ReasoningEffort = shared.ReasoningEffortMinimal
240 case ReasoningEffortLow:
241 params.ReasoningEffort = shared.ReasoningEffortLow
242 case ReasoningEffortMedium:
243 params.ReasoningEffort = shared.ReasoningEffortMedium
244 case ReasoningEffortHigh:
245 params.ReasoningEffort = shared.ReasoningEffortHigh
246 default:
247 return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
248 }
249 }
250
251 if isReasoningModel(model.Model()) {
252 if providerOptions.LogitBias != nil {
253 params.LogitBias = nil
254 warnings = append(warnings, ai.CallWarning{
255 Type: ai.CallWarningTypeUnsupportedSetting,
256 Setting: "LogitBias",
257 Message: "LogitBias is not supported for reasoning models",
258 })
259 }
260 if providerOptions.LogProbs != nil {
261 params.Logprobs = param.Opt[bool]{}
262 warnings = append(warnings, ai.CallWarning{
263 Type: ai.CallWarningTypeUnsupportedSetting,
264 Setting: "Logprobs",
265 Message: "Logprobs is not supported for reasoning models",
266 })
267 }
268 if providerOptions.TopLogProbs != nil {
269 params.TopLogprobs = param.Opt[int64]{}
270 warnings = append(warnings, ai.CallWarning{
271 Type: ai.CallWarningTypeUnsupportedSetting,
272 Setting: "TopLogprobs",
273 Message: "TopLogprobs is not supported for reasoning models",
274 })
275 }
276 }
277
278 // Handle service tier validation
279 if providerOptions.ServiceTier != nil {
280 serviceTier := *providerOptions.ServiceTier
281 if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
282 params.ServiceTier = ""
283 warnings = append(warnings, ai.CallWarning{
284 Type: ai.CallWarningTypeUnsupportedSetting,
285 Setting: "ServiceTier",
286 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
287 })
288 } else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
289 params.ServiceTier = ""
290 warnings = append(warnings, ai.CallWarning{
291 Type: ai.CallWarningTypeUnsupportedSetting,
292 Setting: "ServiceTier",
293 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
294 })
295 }
296 }
297 return warnings, nil
298}
299
// prepareParams translates an ai.Call into OpenAI chat completion request
// params, collecting warnings for settings the target model cannot honor.
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
	params := &openai.ChatCompletionNewParams{}
	messages, warnings := toPrompt(call.Prompt)
	// The chat completions API has no top_k parameter.
	if call.TopK != nil {
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedSetting,
			Setting: "top_k",
		})
	}
	params.Messages = messages
	params.Model = o.modelID

	if call.MaxOutputTokens != nil {
		params.MaxTokens = param.NewOpt(*call.MaxOutputTokens)
	}
	if call.Temperature != nil {
		params.Temperature = param.NewOpt(*call.Temperature)
	}
	if call.TopP != nil {
		params.TopP = param.NewOpt(*call.TopP)
	}
	if call.FrequencyPenalty != nil {
		params.FrequencyPenalty = param.NewOpt(*call.FrequencyPenalty)
	}
	if call.PresencePenalty != nil {
		params.PresencePenalty = param.NewOpt(*call.PresencePenalty)
	}

	if isReasoningModel(o.modelID) {
		// remove unsupported settings for reasoning models
		// see https://platform.openai.com/docs/guides/reasoning#limitations
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for reasoning models",
			})
		}
		if call.TopP != nil {
			params.TopP = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "TopP",
				Details: "TopP is not supported for reasoning models",
			})
		}
		if call.FrequencyPenalty != nil {
			params.FrequencyPenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "FrequencyPenalty",
				Details: "FrequencyPenalty is not supported for reasoning models",
			})
		}
		if call.PresencePenalty != nil {
			params.PresencePenalty = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "PresencePenalty",
				Details: "PresencePenalty is not supported for reasoning models",
			})
		}

		// reasoning models use max_completion_tokens instead of max_tokens
		if call.MaxOutputTokens != nil {
			// NOTE(review): provider options, which may also set
			// MaxCompletionTokens, are applied later (below), so this
			// Valid() guard looks always-false at this point — confirm the
			// intended precedence between call and provider options.
			if !params.MaxCompletionTokens.Valid() {
				params.MaxCompletionTokens = param.NewOpt(*call.MaxOutputTokens)
			}
			params.MaxTokens = param.Opt[int64]{}
		}
	}

	// Handle search preview models
	if isSearchPreviewModel(o.modelID) {
		if call.Temperature != nil {
			params.Temperature = param.Opt[float64]{}
			warnings = append(warnings, ai.CallWarning{
				Type:    ai.CallWarningTypeUnsupportedSetting,
				Setting: "temperature",
				Details: "temperature is not supported for the search preview models and has been removed.",
			})
		}
	}

	// Apply provider-specific options through the default handler unless a
	// hook overrides it.
	prepareOptions := prepareCallWithOptions
	if o.options.hooks.PrepareCallWithOptions != nil {
		prepareOptions = o.options.hooks.PrepareCallWithOptions
	}

	optionsWarnings, err := prepareOptions(o, params, call)
	if err != nil {
		return nil, nil, err
	}

	if len(optionsWarnings) > 0 {
		warnings = append(warnings, optionsWarnings...)
	}

	if len(call.Tools) > 0 {
		tools, toolChoice, toolWarnings := toOpenAiTools(call.Tools, call.ToolChoice)
		params.Tools = tools
		if toolChoice != nil {
			params.ToolChoice = *toolChoice
		}
		warnings = append(warnings, toolWarnings...)
	}
	return params, warnings, nil
}
409
410func (o languageModel) handleError(err error) error {
411 var apiErr *openai.Error
412 if errors.As(err, &apiErr) {
413 requestDump := apiErr.DumpRequest(true)
414 responseDump := apiErr.DumpResponse(true)
415 headers := map[string]string{}
416 for k, h := range apiErr.Response.Header {
417 v := h[len(h)-1]
418 headers[strings.ToLower(k)] = v
419 }
420 return ai.NewAPICallError(
421 apiErr.Message,
422 apiErr.Request.URL.String(),
423 string(requestDump),
424 apiErr.StatusCode,
425 headers,
426 string(responseDump),
427 apiErr,
428 false,
429 )
430 }
431 return err
432}
433
// Generate implements ai.LanguageModel.
//
// It issues a single (non-streaming) chat completion and converts the first
// choice into ai.Content: text, tool calls, and URL-citation sources.
func (o languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
	params, warnings, err := o.prepareParams(call)
	if err != nil {
		return nil, err
	}
	response, err := o.client.Chat.Completions.New(ctx, *params)
	if err != nil {
		return nil, o.handleError(err)
	}

	if len(response.Choices) == 0 {
		return nil, errors.New("no response generated")
	}
	choice := response.Choices[0]
	content := make([]ai.Content, 0, 1+len(choice.Message.ToolCalls)+len(choice.Message.Annotations))
	text := choice.Message.Content
	if text != "" {
		content = append(content, ai.TextContent{
			Text: text,
		})
	}

	// Tool calls without an ID get a generated UUID so callers can always
	// correlate results with calls.
	for _, tc := range choice.Message.ToolCalls {
		toolCallID := tc.ID
		if toolCallID == "" {
			toolCallID = uuid.NewString()
		}
		content = append(content, ai.ToolCallContent{
			ProviderExecuted: false, // TODO: update when handling other tools
			ToolCallID:       toolCallID,
			ToolName:         tc.Function.Name,
			Input:            tc.Function.Arguments,
		})
	}
	// Handle annotations/citations
	for _, annotation := range choice.Message.Annotations {
		if annotation.Type == "url_citation" {
			content = append(content, ai.SourceContent{
				SourceType: ai.SourceTypeURL,
				ID:         uuid.NewString(),
				URL:        annotation.URLCitation.URL,
				Title:      annotation.URLCitation.Title,
			})
		}
	}

	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := &ProviderMetadata{}
	// Add logprobs if available
	if len(choice.Logprobs.Content) > 0 {
		providerMetadata.Logprobs = choice.Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
		if completionTokenDetails.AcceptedPredictionTokens > 0 {
			providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
		}
		if completionTokenDetails.RejectedPredictionTokens > 0 {
			providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
		}
	}

	return &ai.Response{
		Content: content,
		Usage: ai.Usage{
			InputTokens:     response.Usage.PromptTokens,
			OutputTokens:    response.Usage.CompletionTokens,
			TotalTokens:     response.Usage.TotalTokens,
			ReasoningTokens: completionTokenDetails.ReasoningTokens,
			CacheReadTokens: promptTokenDetails.CachedTokens,
		},
		FinishReason: mapOpenAiFinishReason(choice.FinishReason),
		ProviderMetadata: ai.ProviderMetadata{
			Name: providerMetadata,
		},
		Warnings: warnings,
	}, nil
}
517
// toolCall tracks the incremental state of a streamed tool call while its
// JSON arguments arrive across multiple chunks.
type toolCall struct {
	id          string // tool call ID reported by the API
	name        string // function name being invoked
	arguments   string // accumulated JSON argument text
	hasFinished bool   // true once the complete call has been emitted
}
524
525// Stream implements ai.LanguageModel.
526func (o languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
527 params, warnings, err := o.prepareParams(call)
528 if err != nil {
529 return nil, err
530 }
531
532 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
533 IncludeUsage: openai.Bool(true),
534 }
535
536 stream := o.client.Chat.Completions.NewStreaming(ctx, *params)
537 isActiveText := false
538 toolCalls := make(map[int64]toolCall)
539
540 // Build provider metadata for streaming
541 streamProviderMetadata := &ProviderMetadata{}
542 acc := openai.ChatCompletionAccumulator{}
543 var usage ai.Usage
544 return func(yield func(ai.StreamPart) bool) {
545 if len(warnings) > 0 {
546 if !yield(ai.StreamPart{
547 Type: ai.StreamPartTypeWarnings,
548 Warnings: warnings,
549 }) {
550 return
551 }
552 }
553 for stream.Next() {
554 chunk := stream.Current()
555 acc.AddChunk(chunk)
556 if chunk.Usage.TotalTokens > 0 {
557 // we do this here because the acc does not add prompt details
558 completionTokenDetails := chunk.Usage.CompletionTokensDetails
559 promptTokenDetails := chunk.Usage.PromptTokensDetails
560 usage = ai.Usage{
561 InputTokens: chunk.Usage.PromptTokens,
562 OutputTokens: chunk.Usage.CompletionTokens,
563 TotalTokens: chunk.Usage.TotalTokens,
564 ReasoningTokens: completionTokenDetails.ReasoningTokens,
565 CacheReadTokens: promptTokenDetails.CachedTokens,
566 }
567
568 // Add prediction tokens if available
569 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
570 if completionTokenDetails.AcceptedPredictionTokens > 0 {
571 streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
572 }
573 if completionTokenDetails.RejectedPredictionTokens > 0 {
574 streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
575 }
576 }
577 }
578 if len(chunk.Choices) == 0 {
579 continue
580 }
581 for _, choice := range chunk.Choices {
582 switch {
583 case choice.Delta.Content != "":
584 if !isActiveText {
585 isActiveText = true
586 if !yield(ai.StreamPart{
587 Type: ai.StreamPartTypeTextStart,
588 ID: "0",
589 }) {
590 return
591 }
592 }
593 if !yield(ai.StreamPart{
594 Type: ai.StreamPartTypeTextDelta,
595 ID: "0",
596 Delta: choice.Delta.Content,
597 }) {
598 return
599 }
600 case len(choice.Delta.ToolCalls) > 0:
601 if isActiveText {
602 isActiveText = false
603 if !yield(ai.StreamPart{
604 Type: ai.StreamPartTypeTextEnd,
605 ID: "0",
606 }) {
607 return
608 }
609 }
610
611 for _, toolCallDelta := range choice.Delta.ToolCalls {
612 if existingToolCall, ok := toolCalls[toolCallDelta.Index]; ok {
613 if existingToolCall.hasFinished {
614 continue
615 }
616 if toolCallDelta.Function.Arguments != "" {
617 existingToolCall.arguments += toolCallDelta.Function.Arguments
618 }
619 if !yield(ai.StreamPart{
620 Type: ai.StreamPartTypeToolInputDelta,
621 ID: existingToolCall.id,
622 Delta: toolCallDelta.Function.Arguments,
623 }) {
624 return
625 }
626 toolCalls[toolCallDelta.Index] = existingToolCall
627 if xjson.IsValid(existingToolCall.arguments) {
628 if !yield(ai.StreamPart{
629 Type: ai.StreamPartTypeToolInputEnd,
630 ID: existingToolCall.id,
631 }) {
632 return
633 }
634
635 if !yield(ai.StreamPart{
636 Type: ai.StreamPartTypeToolCall,
637 ID: existingToolCall.id,
638 ToolCallName: existingToolCall.name,
639 ToolCallInput: existingToolCall.arguments,
640 }) {
641 return
642 }
643 existingToolCall.hasFinished = true
644 toolCalls[toolCallDelta.Index] = existingToolCall
645 }
646 } else {
647 // Does not exist
648 var err error
649 if toolCallDelta.Type != "function" {
650 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function' type.")
651 }
652 if toolCallDelta.ID == "" {
653 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'id' to be a string.")
654 }
655 if toolCallDelta.Function.Name == "" {
656 err = ai.NewInvalidResponseDataError(toolCallDelta, "Expected 'function.name' to be a string.")
657 }
658 if err != nil {
659 yield(ai.StreamPart{
660 Type: ai.StreamPartTypeError,
661 Error: o.handleError(stream.Err()),
662 })
663 return
664 }
665
666 if !yield(ai.StreamPart{
667 Type: ai.StreamPartTypeToolInputStart,
668 ID: toolCallDelta.ID,
669 ToolCallName: toolCallDelta.Function.Name,
670 }) {
671 return
672 }
673 toolCalls[toolCallDelta.Index] = toolCall{
674 id: toolCallDelta.ID,
675 name: toolCallDelta.Function.Name,
676 arguments: toolCallDelta.Function.Arguments,
677 }
678
679 exTc := toolCalls[toolCallDelta.Index]
680 if exTc.arguments != "" {
681 if !yield(ai.StreamPart{
682 Type: ai.StreamPartTypeToolInputDelta,
683 ID: exTc.id,
684 Delta: exTc.arguments,
685 }) {
686 return
687 }
688 if xjson.IsValid(toolCalls[toolCallDelta.Index].arguments) {
689 if !yield(ai.StreamPart{
690 Type: ai.StreamPartTypeToolInputEnd,
691 ID: toolCallDelta.ID,
692 }) {
693 return
694 }
695
696 if !yield(ai.StreamPart{
697 Type: ai.StreamPartTypeToolCall,
698 ID: exTc.id,
699 ToolCallName: exTc.name,
700 ToolCallInput: exTc.arguments,
701 }) {
702 return
703 }
704 exTc.hasFinished = true
705 toolCalls[toolCallDelta.Index] = exTc
706 }
707 }
708 continue
709 }
710 }
711 }
712 }
713
714 // Check for annotations in the delta's raw JSON
715 for _, choice := range chunk.Choices {
716 if annotations := parseAnnotationsFromDelta(choice.Delta); len(annotations) > 0 {
717 for _, annotation := range annotations {
718 if annotation.Type == "url_citation" {
719 if !yield(ai.StreamPart{
720 Type: ai.StreamPartTypeSource,
721 ID: uuid.NewString(),
722 SourceType: ai.SourceTypeURL,
723 URL: annotation.URLCitation.URL,
724 Title: annotation.URLCitation.Title,
725 }) {
726 return
727 }
728 }
729 }
730 }
731 }
732 }
733 err := stream.Err()
734 if err == nil || errors.Is(err, io.EOF) {
735 // finished
736 if isActiveText {
737 isActiveText = false
738 if !yield(ai.StreamPart{
739 Type: ai.StreamPartTypeTextEnd,
740 ID: "0",
741 }) {
742 return
743 }
744 }
745
746 // Add logprobs if available
747 if len(acc.Choices) > 0 && len(acc.Choices[0].Logprobs.Content) > 0 {
748 streamProviderMetadata.Logprobs = acc.Choices[0].Logprobs.Content
749 }
750
751 // Handle annotations/citations from accumulated response
752 if len(acc.Choices) > 0 {
753 for _, annotation := range acc.Choices[0].Message.Annotations {
754 if annotation.Type == "url_citation" {
755 if !yield(ai.StreamPart{
756 Type: ai.StreamPartTypeSource,
757 ID: acc.ID,
758 SourceType: ai.SourceTypeURL,
759 URL: annotation.URLCitation.URL,
760 Title: annotation.URLCitation.Title,
761 }) {
762 return
763 }
764 }
765 }
766 }
767
768 finishReason := mapOpenAiFinishReason(acc.Choices[0].FinishReason)
769 yield(ai.StreamPart{
770 Type: ai.StreamPartTypeFinish,
771 Usage: usage,
772 FinishReason: finishReason,
773 ProviderMetadata: ai.ProviderMetadata{
774 Name: streamProviderMetadata,
775 },
776 })
777 return
778 } else {
779 yield(ai.StreamPart{
780 Type: ai.StreamPartTypeError,
781 Error: o.handleError(err),
782 })
783 return
784 }
785 }, nil
786}
787
788func (o *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
789 var options ProviderOptions
790 if err := ai.ParseOptions(data, &options); err != nil {
791 return nil, err
792 }
793 return &options, nil
794}
795
// Name returns the canonical provider identifier ("openai").
func (o *provider) Name() string {
	return Name
}
799
// mapOpenAiFinishReason translates an OpenAI finish reason string into the
// ai package's FinishReason; unrecognized values map to unknown.
func mapOpenAiFinishReason(finishReason string) ai.FinishReason {
	switch finishReason {
	case "stop":
		return ai.FinishReasonStop
	case "length":
		return ai.FinishReasonLength
	case "content_filter":
		return ai.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return ai.FinishReasonToolCalls
	default:
		return ai.FinishReasonUnknown
	}
}
814
// isReasoningModel reports whether the model uses the reasoning API surface
// (o-series and the gpt-5 family). gpt-5-chat models are conventional chat
// models and are excluded: the previous `|| HasPrefix("gpt-5-chat")` clause
// was unreachable (subsumed by the "gpt-5" prefix) and inverted the
// evidently intended exclusion.
func isReasoningModel(modelID string) bool {
	return strings.HasPrefix(modelID, "o") ||
		(strings.HasPrefix(modelID, "gpt-5") && !strings.HasPrefix(modelID, "gpt-5-chat"))
}
818
// isSearchPreviewModel reports whether the model is a web-search preview
// chat model (these reject the temperature setting).
func isSearchPreviewModel(modelID string) bool {
	return strings.Contains(modelID, "search-preview")
}
822
// supportsFlexProcessing reports whether the model may use the "flex"
// service tier (o3, o4-mini, and gpt-5 families).
func supportsFlexProcessing(modelID string) bool {
	for _, prefix := range []string{"o3", "o4-mini", "gpt-5"} {
		if strings.HasPrefix(modelID, prefix) {
			return true
		}
	}
	return false
}
826
// supportsPriorityProcessing reports whether the model may use the
// "priority" service tier (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini).
// gpt-5-nano is explicitly excluded, matching the warning text emitted by
// prepareCallWithOptions; previously the bare "gpt-5" prefix check wrongly
// matched gpt-5-nano (and the "gpt-5-mini" clause was redundant).
func supportsPriorityProcessing(modelID string) bool {
	if strings.HasPrefix(modelID, "gpt-5-nano") {
		return false
	}
	return strings.HasPrefix(modelID, "gpt-4") ||
		strings.HasPrefix(modelID, "gpt-5") ||
		strings.HasPrefix(modelID, "o3") ||
		strings.HasPrefix(modelID, "o4-mini")
}
832
// toOpenAiTools converts ai tools and the tool-choice setting into OpenAI
// request shapes. Non-function tools are reported as warnings rather than
// errors.
func toOpenAiTools(tools []ai.Tool, toolChoice *ai.ToolChoice) (openAiTools []openai.ChatCompletionToolUnionParam, openAiToolChoice *openai.ChatCompletionToolChoiceOptionUnionParam, warnings []ai.CallWarning) {
	for _, tool := range tools {
		if tool.GetType() == ai.ToolTypeFunction {
			ft, ok := tool.(ai.FunctionTool)
			if !ok {
				continue
			}
			openAiTools = append(openAiTools, openai.ChatCompletionToolUnionParam{
				OfFunction: &openai.ChatCompletionFunctionToolParam{
					Function: shared.FunctionDefinitionParam{
						Name:        ft.Name,
						Description: param.NewOpt(ft.Description),
						Parameters:  openai.FunctionParameters(ft.InputSchema),
						// Strict JSON-schema mode is deliberately disabled.
						Strict: param.NewOpt(false),
					},
					Type: "function",
				},
			})
			continue
		}

		// TODO: handle provider tool calls
		warnings = append(warnings, ai.CallWarning{
			Type:    ai.CallWarningTypeUnsupportedTool,
			Tool:    tool,
			Message: "tool is not supported",
		})
	}
	if toolChoice == nil {
		return openAiTools, openAiToolChoice, warnings
	}

	// NOTE(review): any value other than auto/none falls through to the
	// named-function branch below — if the ai package defines a "required"
	// tool choice, confirm it should not map to OfAuto: "required" instead
	// of being treated as a function name.
	switch *toolChoice {
	case ai.ToolChoiceAuto:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("auto"),
		}
	case ai.ToolChoiceNone:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfAuto: param.NewOpt("none"),
		}
	default:
		openAiToolChoice = &openai.ChatCompletionToolChoiceOptionUnionParam{
			OfFunctionToolChoice: &openai.ChatCompletionNamedToolChoiceParam{
				Type: "function",
				Function: openai.ChatCompletionNamedToolChoiceFunctionParam{
					Name: string(*toolChoice),
				},
			},
		}
	}
	return openAiTools, openAiToolChoice, warnings
}
886
887func toPrompt(prompt ai.Prompt) ([]openai.ChatCompletionMessageParamUnion, []ai.CallWarning) {
888 var messages []openai.ChatCompletionMessageParamUnion
889 var warnings []ai.CallWarning
890 for _, msg := range prompt {
891 switch msg.Role {
892 case ai.MessageRoleSystem:
893 var systemPromptParts []string
894 for _, c := range msg.Content {
895 if c.GetType() != ai.ContentTypeText {
896 warnings = append(warnings, ai.CallWarning{
897 Type: ai.CallWarningTypeOther,
898 Message: "system prompt can only have text content",
899 })
900 continue
901 }
902 textPart, ok := ai.AsContentType[ai.TextPart](c)
903 if !ok {
904 warnings = append(warnings, ai.CallWarning{
905 Type: ai.CallWarningTypeOther,
906 Message: "system prompt text part does not have the right type",
907 })
908 continue
909 }
910 text := textPart.Text
911 if strings.TrimSpace(text) != "" {
912 systemPromptParts = append(systemPromptParts, textPart.Text)
913 }
914 }
915 if len(systemPromptParts) == 0 {
916 warnings = append(warnings, ai.CallWarning{
917 Type: ai.CallWarningTypeOther,
918 Message: "system prompt has no text parts",
919 })
920 continue
921 }
922 messages = append(messages, openai.SystemMessage(strings.Join(systemPromptParts, "\n")))
923 case ai.MessageRoleUser:
924 // simple user message just text content
925 if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
926 textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
927 if !ok {
928 warnings = append(warnings, ai.CallWarning{
929 Type: ai.CallWarningTypeOther,
930 Message: "user message text part does not have the right type",
931 })
932 continue
933 }
934 messages = append(messages, openai.UserMessage(textPart.Text))
935 continue
936 }
937 // text content and attachments
938 // for now we only support image content later we need to check
939 // TODO: add the supported media types to the language model so we
940 // can use that to validate the data here.
941 var content []openai.ChatCompletionContentPartUnionParam
942 for _, c := range msg.Content {
943 switch c.GetType() {
944 case ai.ContentTypeText:
945 textPart, ok := ai.AsContentType[ai.TextPart](c)
946 if !ok {
947 warnings = append(warnings, ai.CallWarning{
948 Type: ai.CallWarningTypeOther,
949 Message: "user message text part does not have the right type",
950 })
951 continue
952 }
953 content = append(content, openai.ChatCompletionContentPartUnionParam{
954 OfText: &openai.ChatCompletionContentPartTextParam{
955 Text: textPart.Text,
956 },
957 })
958 case ai.ContentTypeFile:
959 filePart, ok := ai.AsContentType[ai.FilePart](c)
960 if !ok {
961 warnings = append(warnings, ai.CallWarning{
962 Type: ai.CallWarningTypeOther,
963 Message: "user message file part does not have the right type",
964 })
965 continue
966 }
967
968 switch {
969 case strings.HasPrefix(filePart.MediaType, "image/"):
970 // Handle image files
971 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
972 data := "data:" + filePart.MediaType + ";base64," + base64Encoded
973 imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: data}
974
975 // Check for provider-specific options like image detail
976 if providerOptions, ok := filePart.ProviderOptions[Name]; ok {
977 if detail, ok := providerOptions.(*ProviderFileOptions); ok {
978 imageURL.Detail = detail.ImageDetail
979 }
980 }
981
982 imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
983 content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
984
985 case filePart.MediaType == "audio/wav":
986 // Handle WAV audio files
987 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
988 audioBlock := openai.ChatCompletionContentPartInputAudioParam{
989 InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
990 Data: base64Encoded,
991 Format: "wav",
992 },
993 }
994 content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
995
996 case filePart.MediaType == "audio/mpeg" || filePart.MediaType == "audio/mp3":
997 // Handle MP3 audio files
998 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
999 audioBlock := openai.ChatCompletionContentPartInputAudioParam{
1000 InputAudio: openai.ChatCompletionContentPartInputAudioInputAudioParam{
1001 Data: base64Encoded,
1002 Format: "mp3",
1003 },
1004 }
1005 content = append(content, openai.ChatCompletionContentPartUnionParam{OfInputAudio: &audioBlock})
1006
1007 case filePart.MediaType == "application/pdf":
1008 // Handle PDF files
1009 dataStr := string(filePart.Data)
1010
1011 // Check if data looks like a file ID (starts with "file-")
1012 if strings.HasPrefix(dataStr, "file-") {
1013 fileBlock := openai.ChatCompletionContentPartFileParam{
1014 File: openai.ChatCompletionContentPartFileFileParam{
1015 FileID: param.NewOpt(dataStr),
1016 },
1017 }
1018 content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1019 } else {
1020 // Handle as base64 data
1021 base64Encoded := base64.StdEncoding.EncodeToString(filePart.Data)
1022 data := "data:application/pdf;base64," + base64Encoded
1023
1024 filename := filePart.Filename
1025 if filename == "" {
1026 // Generate default filename based on content index
1027 filename = fmt.Sprintf("part-%d.pdf", len(content))
1028 }
1029
1030 fileBlock := openai.ChatCompletionContentPartFileParam{
1031 File: openai.ChatCompletionContentPartFileFileParam{
1032 Filename: param.NewOpt(filename),
1033 FileData: param.NewOpt(data),
1034 },
1035 }
1036 content = append(content, openai.ChatCompletionContentPartUnionParam{OfFile: &fileBlock})
1037 }
1038
1039 default:
1040 warnings = append(warnings, ai.CallWarning{
1041 Type: ai.CallWarningTypeOther,
1042 Message: fmt.Sprintf("file part media type %s not supported", filePart.MediaType),
1043 })
1044 }
1045 }
1046 }
1047 messages = append(messages, openai.UserMessage(content))
1048 case ai.MessageRoleAssistant:
1049 // simple assistant message just text content
1050 if len(msg.Content) == 1 && msg.Content[0].GetType() == ai.ContentTypeText {
1051 textPart, ok := ai.AsContentType[ai.TextPart](msg.Content[0])
1052 if !ok {
1053 warnings = append(warnings, ai.CallWarning{
1054 Type: ai.CallWarningTypeOther,
1055 Message: "assistant message text part does not have the right type",
1056 })
1057 continue
1058 }
1059 messages = append(messages, openai.AssistantMessage(textPart.Text))
1060 continue
1061 }
1062 assistantMsg := openai.ChatCompletionAssistantMessageParam{
1063 Role: "assistant",
1064 }
1065 for _, c := range msg.Content {
1066 switch c.GetType() {
1067 case ai.ContentTypeText:
1068 textPart, ok := ai.AsContentType[ai.TextPart](c)
1069 if !ok {
1070 warnings = append(warnings, ai.CallWarning{
1071 Type: ai.CallWarningTypeOther,
1072 Message: "assistant message text part does not have the right type",
1073 })
1074 continue
1075 }
1076 assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
1077 OfString: param.NewOpt(textPart.Text),
1078 }
1079 case ai.ContentTypeToolCall:
1080 toolCallPart, ok := ai.AsContentType[ai.ToolCallPart](c)
1081 if !ok {
1082 warnings = append(warnings, ai.CallWarning{
1083 Type: ai.CallWarningTypeOther,
1084 Message: "assistant message tool part does not have the right type",
1085 })
1086 continue
1087 }
1088 assistantMsg.ToolCalls = append(assistantMsg.ToolCalls,
1089 openai.ChatCompletionMessageToolCallUnionParam{
1090 OfFunction: &openai.ChatCompletionMessageFunctionToolCallParam{
1091 ID: toolCallPart.ToolCallID,
1092 Type: "function",
1093 Function: openai.ChatCompletionMessageFunctionToolCallFunctionParam{
1094 Name: toolCallPart.ToolName,
1095 Arguments: toolCallPart.Input,
1096 },
1097 },
1098 })
1099 }
1100 }
1101 messages = append(messages, openai.ChatCompletionMessageParamUnion{
1102 OfAssistant: &assistantMsg,
1103 })
1104 case ai.MessageRoleTool:
1105 for _, c := range msg.Content {
1106 if c.GetType() != ai.ContentTypeToolResult {
1107 warnings = append(warnings, ai.CallWarning{
1108 Type: ai.CallWarningTypeOther,
1109 Message: "tool message can only have tool result content",
1110 })
1111 continue
1112 }
1113
1114 toolResultPart, ok := ai.AsContentType[ai.ToolResultPart](c)
1115 if !ok {
1116 warnings = append(warnings, ai.CallWarning{
1117 Type: ai.CallWarningTypeOther,
1118 Message: "tool message result part does not have the right type",
1119 })
1120 continue
1121 }
1122
1123 switch toolResultPart.Output.GetType() {
1124 case ai.ToolResultContentTypeText:
1125 output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](toolResultPart.Output)
1126 if !ok {
1127 warnings = append(warnings, ai.CallWarning{
1128 Type: ai.CallWarningTypeOther,
1129 Message: "tool result output does not have the right type",
1130 })
1131 continue
1132 }
1133 messages = append(messages, openai.ToolMessage(output.Text, toolResultPart.ToolCallID))
1134 case ai.ToolResultContentTypeError:
1135 // TODO: check if better handling is needed
1136 output, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](toolResultPart.Output)
1137 if !ok {
1138 warnings = append(warnings, ai.CallWarning{
1139 Type: ai.CallWarningTypeOther,
1140 Message: "tool result output does not have the right type",
1141 })
1142 continue
1143 }
1144 messages = append(messages, openai.ToolMessage(output.Error.Error(), toolResultPart.ToolCallID))
1145 }
1146 }
1147 }
1148 }
1149 return messages, warnings
1150}
1151
1152// parseAnnotationsFromDelta parses annotations from the raw JSON of a delta.
1153func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []openai.ChatCompletionMessageAnnotation {
1154 var annotations []openai.ChatCompletionMessageAnnotation
1155
1156 // Parse the raw JSON to extract annotations
1157 var deltaData map[string]any
1158 if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
1159 return annotations
1160 }
1161
1162 // Check if annotations exist in the delta
1163 if annotationsData, ok := deltaData["annotations"].([]any); ok {
1164 for _, annotationData := range annotationsData {
1165 if annotationMap, ok := annotationData.(map[string]any); ok {
1166 if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
1167 if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
1168 annotation := openai.ChatCompletionMessageAnnotation{
1169 Type: "url_citation",
1170 URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{
1171 URL: urlCitationData["url"].(string),
1172 Title: urlCitationData["title"].(string),
1173 },
1174 }
1175 annotations = append(annotations, annotation)
1176 }
1177 }
1178 }
1179 }
1180 }
1181
1182 return annotations
1183}