1package openai
2
3import (
4 "fmt"
5
6 "github.com/charmbracelet/fantasy/ai"
7 "github.com/openai/openai-go/v2"
8 "github.com/openai/openai-go/v2/packages/param"
9 "github.com/openai/openai-go/v2/shared"
10)
11
12type (
13 LanguageModelPrepareCallFunc = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
14 LanguageModelMapFinishReasonFunc = func(choice openai.ChatCompletionChoice) ai.FinishReason
15 LanguageModelUsageFunc = func(choice openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData)
16 LanguageModelExtraContentFunc = func(choice openai.ChatCompletionChoice) []ai.Content
17 LanguageModelStreamExtraFunc = func(chunk openai.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool)
18 LanguageModelStreamUsageFunc = func(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata)
19 LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata
20)
21
22func defaultPrepareLanguageModelCall(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
23 if call.ProviderOptions == nil {
24 return nil, nil
25 }
26 var warnings []ai.CallWarning
27 providerOptions := &ProviderOptions{}
28 if v, ok := call.ProviderOptions[Name]; ok {
29 providerOptions, ok = v.(*ProviderOptions)
30 if !ok {
31 return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
32 }
33 }
34
35 if providerOptions.LogitBias != nil {
36 params.LogitBias = providerOptions.LogitBias
37 }
38 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
39 providerOptions.LogProbs = nil
40 }
41 if providerOptions.LogProbs != nil {
42 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
43 }
44 if providerOptions.TopLogProbs != nil {
45 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
46 }
47 if providerOptions.User != nil {
48 params.User = param.NewOpt(*providerOptions.User)
49 }
50 if providerOptions.ParallelToolCalls != nil {
51 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
52 }
53 if providerOptions.MaxCompletionTokens != nil {
54 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
55 }
56
57 if providerOptions.TextVerbosity != nil {
58 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
59 }
60 if providerOptions.Prediction != nil {
61 // Convert map[string]any to ChatCompletionPredictionContentParam
62 if content, ok := providerOptions.Prediction["content"]; ok {
63 if contentStr, ok := content.(string); ok {
64 params.Prediction = openai.ChatCompletionPredictionContentParam{
65 Content: openai.ChatCompletionPredictionContentContentUnionParam{
66 OfString: param.NewOpt(contentStr),
67 },
68 }
69 }
70 }
71 }
72 if providerOptions.Store != nil {
73 params.Store = param.NewOpt(*providerOptions.Store)
74 }
75 if providerOptions.Metadata != nil {
76 // Convert map[string]any to map[string]string
77 metadata := make(map[string]string)
78 for k, v := range providerOptions.Metadata {
79 if str, ok := v.(string); ok {
80 metadata[k] = str
81 }
82 }
83 params.Metadata = metadata
84 }
85 if providerOptions.PromptCacheKey != nil {
86 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
87 }
88 if providerOptions.SafetyIdentifier != nil {
89 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
90 }
91 if providerOptions.ServiceTier != nil {
92 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
93 }
94
95 if providerOptions.ReasoningEffort != nil {
96 switch *providerOptions.ReasoningEffort {
97 case ReasoningEffortMinimal:
98 params.ReasoningEffort = shared.ReasoningEffortMinimal
99 case ReasoningEffortLow:
100 params.ReasoningEffort = shared.ReasoningEffortLow
101 case ReasoningEffortMedium:
102 params.ReasoningEffort = shared.ReasoningEffortMedium
103 case ReasoningEffortHigh:
104 params.ReasoningEffort = shared.ReasoningEffortHigh
105 default:
106 return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
107 }
108 }
109
110 if isReasoningModel(model.Model()) {
111 if providerOptions.LogitBias != nil {
112 params.LogitBias = nil
113 warnings = append(warnings, ai.CallWarning{
114 Type: ai.CallWarningTypeUnsupportedSetting,
115 Setting: "LogitBias",
116 Message: "LogitBias is not supported for reasoning models",
117 })
118 }
119 if providerOptions.LogProbs != nil {
120 params.Logprobs = param.Opt[bool]{}
121 warnings = append(warnings, ai.CallWarning{
122 Type: ai.CallWarningTypeUnsupportedSetting,
123 Setting: "Logprobs",
124 Message: "Logprobs is not supported for reasoning models",
125 })
126 }
127 if providerOptions.TopLogProbs != nil {
128 params.TopLogprobs = param.Opt[int64]{}
129 warnings = append(warnings, ai.CallWarning{
130 Type: ai.CallWarningTypeUnsupportedSetting,
131 Setting: "TopLogprobs",
132 Message: "TopLogprobs is not supported for reasoning models",
133 })
134 }
135 }
136
137 // Handle service tier validation
138 if providerOptions.ServiceTier != nil {
139 serviceTier := *providerOptions.ServiceTier
140 if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
141 params.ServiceTier = ""
142 warnings = append(warnings, ai.CallWarning{
143 Type: ai.CallWarningTypeUnsupportedSetting,
144 Setting: "ServiceTier",
145 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
146 })
147 } else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
148 params.ServiceTier = ""
149 warnings = append(warnings, ai.CallWarning{
150 Type: ai.CallWarningTypeUnsupportedSetting,
151 Setting: "ServiceTier",
152 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
153 })
154 }
155 }
156 return warnings, nil
157}
158
159func defaultMapFinishReason(choice openai.ChatCompletionChoice) ai.FinishReason {
160 finishReason := choice.FinishReason
161 switch finishReason {
162 case "stop":
163 return ai.FinishReasonStop
164 case "length":
165 return ai.FinishReasonLength
166 case "content_filter":
167 return ai.FinishReasonContentFilter
168 case "function_call", "tool_calls":
169 return ai.FinishReasonToolCalls
170 default:
171 return ai.FinishReasonUnknown
172 }
173}
174
175func defaultUsage(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
176 if len(response.Choices) == 0 {
177 return ai.Usage{}, nil
178 }
179 choice := response.Choices[0]
180 completionTokenDetails := response.Usage.CompletionTokensDetails
181 promptTokenDetails := response.Usage.PromptTokensDetails
182
183 // Build provider metadata
184 providerMetadata := &ProviderMetadata{}
185 // Add logprobs if available
186 if len(choice.Logprobs.Content) > 0 {
187 providerMetadata.Logprobs = choice.Logprobs.Content
188 }
189
190 // Add prediction tokens if available
191 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
192 if completionTokenDetails.AcceptedPredictionTokens > 0 {
193 providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
194 }
195 if completionTokenDetails.RejectedPredictionTokens > 0 {
196 providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
197 }
198 }
199 return ai.Usage{
200 InputTokens: response.Usage.PromptTokens,
201 OutputTokens: response.Usage.CompletionTokens,
202 TotalTokens: response.Usage.TotalTokens,
203 ReasoningTokens: completionTokenDetails.ReasoningTokens,
204 CacheReadTokens: promptTokenDetails.CachedTokens,
205 }, providerMetadata
206}
207
208func defaultStreamUsage(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
209 if chunk.Usage.TotalTokens == 0 {
210 return ai.Usage{}, nil
211 }
212 streamProviderMetadata := &ProviderMetadata{}
213 if metadata != nil {
214 if providerMetadata, ok := metadata[Name]; ok {
215 converted, ok := providerMetadata.(*ProviderMetadata)
216 if ok {
217 streamProviderMetadata = converted
218 }
219 }
220 }
221 // we do this here because the acc does not add prompt details
222 completionTokenDetails := chunk.Usage.CompletionTokensDetails
223 promptTokenDetails := chunk.Usage.PromptTokensDetails
224 usage := ai.Usage{
225 InputTokens: chunk.Usage.PromptTokens,
226 OutputTokens: chunk.Usage.CompletionTokens,
227 TotalTokens: chunk.Usage.TotalTokens,
228 ReasoningTokens: completionTokenDetails.ReasoningTokens,
229 CacheReadTokens: promptTokenDetails.CachedTokens,
230 }
231
232 // Add prediction tokens if available
233 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
234 if completionTokenDetails.AcceptedPredictionTokens > 0 {
235 streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
236 }
237 if completionTokenDetails.RejectedPredictionTokens > 0 {
238 streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
239 }
240 }
241
242 return usage, ai.ProviderMetadata{
243 Name: streamProviderMetadata,
244 }
245}
246
247func defaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata {
248 streamProviderMetadata, ok := metadata[Name]
249 if !ok {
250 streamProviderMetadata = &ProviderMetadata{}
251 }
252 if converted, ok := streamProviderMetadata.(*ProviderMetadata); ok {
253 converted.Logprobs = choice.Logprobs.Content
254 metadata[Name] = converted
255 }
256 return metadata
257}