1package openai
2
3import (
4 "fmt"
5
6 "github.com/charmbracelet/fantasy/ai"
7 "github.com/openai/openai-go/v2"
8 "github.com/openai/openai-go/v2/packages/param"
9 "github.com/openai/openai-go/v2/shared"
10)
11
12type (
13 LanguageModelPrepareCallFunc = func(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error)
14 LanguageModelMapFinishReasonFunc = func(finishReason string) ai.FinishReason
15 LanguageModelUsageFunc = func(choice openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData)
16 LanguageModelExtraContentFunc = func(choice openai.ChatCompletionChoice) []ai.Content
17 LanguageModelStreamExtraFunc = func(chunk openai.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool)
18 LanguageModelStreamUsageFunc = func(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata)
19 LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata
20)
21
22func DefaultPrepareCallFunc(model ai.LanguageModel, params *openai.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
23 if call.ProviderOptions == nil {
24 return nil, nil
25 }
26 var warnings []ai.CallWarning
27 providerOptions := &ProviderOptions{}
28 if v, ok := call.ProviderOptions[Name]; ok {
29 providerOptions, ok = v.(*ProviderOptions)
30 if !ok {
31 return nil, ai.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
32 }
33 }
34
35 if providerOptions.LogitBias != nil {
36 params.LogitBias = providerOptions.LogitBias
37 }
38 if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
39 providerOptions.LogProbs = nil
40 }
41 if providerOptions.LogProbs != nil {
42 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
43 }
44 if providerOptions.TopLogProbs != nil {
45 params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
46 }
47 if providerOptions.User != nil {
48 params.User = param.NewOpt(*providerOptions.User)
49 }
50 if providerOptions.ParallelToolCalls != nil {
51 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
52 }
53 if providerOptions.MaxCompletionTokens != nil {
54 params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
55 }
56
57 if providerOptions.TextVerbosity != nil {
58 params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
59 }
60 if providerOptions.Prediction != nil {
61 // Convert map[string]any to ChatCompletionPredictionContentParam
62 if content, ok := providerOptions.Prediction["content"]; ok {
63 if contentStr, ok := content.(string); ok {
64 params.Prediction = openai.ChatCompletionPredictionContentParam{
65 Content: openai.ChatCompletionPredictionContentContentUnionParam{
66 OfString: param.NewOpt(contentStr),
67 },
68 }
69 }
70 }
71 }
72 if providerOptions.Store != nil {
73 params.Store = param.NewOpt(*providerOptions.Store)
74 }
75 if providerOptions.Metadata != nil {
76 // Convert map[string]any to map[string]string
77 metadata := make(map[string]string)
78 for k, v := range providerOptions.Metadata {
79 if str, ok := v.(string); ok {
80 metadata[k] = str
81 }
82 }
83 params.Metadata = metadata
84 }
85 if providerOptions.PromptCacheKey != nil {
86 params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
87 }
88 if providerOptions.SafetyIdentifier != nil {
89 params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
90 }
91 if providerOptions.ServiceTier != nil {
92 params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
93 }
94
95 if providerOptions.ReasoningEffort != nil {
96 switch *providerOptions.ReasoningEffort {
97 case ReasoningEffortMinimal:
98 params.ReasoningEffort = shared.ReasoningEffortMinimal
99 case ReasoningEffortLow:
100 params.ReasoningEffort = shared.ReasoningEffortLow
101 case ReasoningEffortMedium:
102 params.ReasoningEffort = shared.ReasoningEffortMedium
103 case ReasoningEffortHigh:
104 params.ReasoningEffort = shared.ReasoningEffortHigh
105 default:
106 return nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
107 }
108 }
109
110 if isReasoningModel(model.Model()) {
111 if providerOptions.LogitBias != nil {
112 params.LogitBias = nil
113 warnings = append(warnings, ai.CallWarning{
114 Type: ai.CallWarningTypeUnsupportedSetting,
115 Setting: "LogitBias",
116 Message: "LogitBias is not supported for reasoning models",
117 })
118 }
119 if providerOptions.LogProbs != nil {
120 params.Logprobs = param.Opt[bool]{}
121 warnings = append(warnings, ai.CallWarning{
122 Type: ai.CallWarningTypeUnsupportedSetting,
123 Setting: "Logprobs",
124 Message: "Logprobs is not supported for reasoning models",
125 })
126 }
127 if providerOptions.TopLogProbs != nil {
128 params.TopLogprobs = param.Opt[int64]{}
129 warnings = append(warnings, ai.CallWarning{
130 Type: ai.CallWarningTypeUnsupportedSetting,
131 Setting: "TopLogprobs",
132 Message: "TopLogprobs is not supported for reasoning models",
133 })
134 }
135 }
136
137 // Handle service tier validation
138 if providerOptions.ServiceTier != nil {
139 serviceTier := *providerOptions.ServiceTier
140 if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
141 params.ServiceTier = ""
142 warnings = append(warnings, ai.CallWarning{
143 Type: ai.CallWarningTypeUnsupportedSetting,
144 Setting: "ServiceTier",
145 Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
146 })
147 } else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
148 params.ServiceTier = ""
149 warnings = append(warnings, ai.CallWarning{
150 Type: ai.CallWarningTypeUnsupportedSetting,
151 Setting: "ServiceTier",
152 Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access. gpt-5-nano is not supported",
153 })
154 }
155 }
156 return warnings, nil
157}
158
159func DefaultMapFinishReasonFunc(finishReason string) ai.FinishReason {
160 switch finishReason {
161 case "stop":
162 return ai.FinishReasonStop
163 case "length":
164 return ai.FinishReasonLength
165 case "content_filter":
166 return ai.FinishReasonContentFilter
167 case "function_call", "tool_calls":
168 return ai.FinishReasonToolCalls
169 default:
170 return ai.FinishReasonUnknown
171 }
172}
173
174func DefaultUsageFunc(response openai.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
175 completionTokenDetails := response.Usage.CompletionTokensDetails
176 promptTokenDetails := response.Usage.PromptTokensDetails
177
178 // Build provider metadata
179 providerMetadata := &ProviderMetadata{}
180
181 // Add logprobs if available
182 if len(response.Choices) > 0 && len(response.Choices[0].Logprobs.Content) > 0 {
183 providerMetadata.Logprobs = response.Choices[0].Logprobs.Content
184 }
185
186 // Add prediction tokens if available
187 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
188 if completionTokenDetails.AcceptedPredictionTokens > 0 {
189 providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
190 }
191 if completionTokenDetails.RejectedPredictionTokens > 0 {
192 providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
193 }
194 }
195 return ai.Usage{
196 InputTokens: response.Usage.PromptTokens,
197 OutputTokens: response.Usage.CompletionTokens,
198 TotalTokens: response.Usage.TotalTokens,
199 ReasoningTokens: completionTokenDetails.ReasoningTokens,
200 CacheReadTokens: promptTokenDetails.CachedTokens,
201 }, providerMetadata
202}
203
204func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
205 if chunk.Usage.TotalTokens == 0 {
206 return ai.Usage{}, nil
207 }
208 streamProviderMetadata := &ProviderMetadata{}
209 if metadata != nil {
210 if providerMetadata, ok := metadata[Name]; ok {
211 converted, ok := providerMetadata.(*ProviderMetadata)
212 if ok {
213 streamProviderMetadata = converted
214 }
215 }
216 }
217 // we do this here because the acc does not add prompt details
218 completionTokenDetails := chunk.Usage.CompletionTokensDetails
219 promptTokenDetails := chunk.Usage.PromptTokensDetails
220 usage := ai.Usage{
221 InputTokens: chunk.Usage.PromptTokens,
222 OutputTokens: chunk.Usage.CompletionTokens,
223 TotalTokens: chunk.Usage.TotalTokens,
224 ReasoningTokens: completionTokenDetails.ReasoningTokens,
225 CacheReadTokens: promptTokenDetails.CachedTokens,
226 }
227
228 // Add prediction tokens if available
229 if completionTokenDetails.AcceptedPredictionTokens > 0 || completionTokenDetails.RejectedPredictionTokens > 0 {
230 if completionTokenDetails.AcceptedPredictionTokens > 0 {
231 streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
232 }
233 if completionTokenDetails.RejectedPredictionTokens > 0 {
234 streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
235 }
236 }
237
238 return usage, ai.ProviderMetadata{
239 Name: streamProviderMetadata,
240 }
241}
242
243func DefaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata ai.ProviderMetadata) ai.ProviderMetadata {
244 streamProviderMetadata, ok := metadata[Name]
245 if !ok {
246 streamProviderMetadata = &ProviderMetadata{}
247 }
248 if converted, ok := streamProviderMetadata.(*ProviderMetadata); ok {
249 converted.Logprobs = choice.Logprobs.Content
250 metadata[Name] = converted
251 }
252 return metadata
253}