package openai

import (
	"fmt"

	"charm.land/fantasy"
	"github.com/openai/openai-go/v2"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/openai/openai-go/v2/shared"
)

// LanguageModelPrepareCallFunc prepares the request parameters for a language model call.
type LanguageModelPrepareCallFunc = func(model fantasy.LanguageModel, params *openai.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error)

// LanguageModelMapFinishReasonFunc maps a provider finish reason to a fantasy finish reason.
type LanguageModelMapFinishReasonFunc = func(finishReason string) fantasy.FinishReason

// LanguageModelUsageFunc calculates usage from a completed response.
type LanguageModelUsageFunc = func(response openai.ChatCompletion) (fantasy.Usage, fantasy.ProviderOptionsData)

// LanguageModelExtraContentFunc adds extra content derived from a response choice.
type LanguageModelExtraContentFunc = func(choice openai.ChatCompletionChoice) []fantasy.Content

// LanguageModelStreamExtraFunc handles provider-specific extra stream functionality
// for a chunk, returning the updated per-stream context.
type LanguageModelStreamExtraFunc = func(chunk openai.ChatCompletionChunk, yield func(fantasy.StreamPart) bool, ctx map[string]any) (map[string]any, bool)

// LanguageModelStreamUsageFunc calculates usage from a stream chunk.
type LanguageModelStreamUsageFunc = func(chunk openai.ChatCompletionChunk, ctx map[string]any, metadata fantasy.ProviderMetadata) (fantasy.Usage, fantasy.ProviderMetadata)

// LanguageModelStreamProviderMetadataFunc updates provider metadata from a stream choice.
type LanguageModelStreamProviderMetadataFunc = func(choice openai.ChatCompletionChoice, metadata fantasy.ProviderMetadata) fantasy.ProviderMetadata
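
// The hook types above are the provider's customization points. As a minimal,
// compile-checked sketch (how a hook is registered with a provider is assumed
// and not shown here), a custom prepare-call hook can wrap the default and
// then adjust the prepared parameters:
var _ LanguageModelPrepareCallFunc = func(model fantasy.LanguageModel, params *openai.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
	warnings, err := DefaultPrepareCallFunc(model, params, call)
	if err != nil {
		return nil, err
	}
	// Hypothetical tweak: pin the temperature after the defaults have run.
	params.Temperature = param.NewOpt(0.0)
	return warnings, nil
}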

// DefaultPrepareCallFunc is the default implementation for preparing a call to the language model.
func DefaultPrepareCallFunc(model fantasy.LanguageModel, params *openai.ChatCompletionNewParams, call fantasy.Call) ([]fantasy.CallWarning, error) {
	if call.ProviderOptions == nil {
		return nil, nil
	}
	var warnings []fantasy.CallWarning
	providerOptions := &ProviderOptions{}
	if v, ok := call.ProviderOptions[Name]; ok {
		providerOptions, ok = v.(*ProviderOptions)
		if !ok {
			return nil, fantasy.NewInvalidArgumentError("providerOptions", "openai provider options should be *openai.ProviderOptions", nil)
		}
	}

	if providerOptions.LogitBias != nil {
		params.LogitBias = providerOptions.LogitBias
	}
	// TopLogProbs takes precedence when both are set.
	if providerOptions.LogProbs != nil && providerOptions.TopLogProbs != nil {
		providerOptions.LogProbs = nil
	}
	if providerOptions.LogProbs != nil {
		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
	}
	if providerOptions.TopLogProbs != nil {
		params.TopLogprobs = param.NewOpt(*providerOptions.TopLogProbs)
	}
	if providerOptions.User != nil {
		params.User = param.NewOpt(*providerOptions.User)
	}
	if providerOptions.ParallelToolCalls != nil {
		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
	}
	if providerOptions.MaxCompletionTokens != nil {
		params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens)
	}

	if providerOptions.TextVerbosity != nil {
		params.Verbosity = openai.ChatCompletionNewParamsVerbosity(*providerOptions.TextVerbosity)
	}
	if providerOptions.Prediction != nil {
		// Convert map[string]any to ChatCompletionPredictionContentParam.
		if content, ok := providerOptions.Prediction["content"]; ok {
			if contentStr, ok := content.(string); ok {
				params.Prediction = openai.ChatCompletionPredictionContentParam{
					Content: openai.ChatCompletionPredictionContentContentUnionParam{
						OfString: param.NewOpt(contentStr),
					},
				}
			}
		}
	}
	if providerOptions.Store != nil {
		params.Store = param.NewOpt(*providerOptions.Store)
	}
	if providerOptions.Metadata != nil {
		// Convert map[string]any to map[string]string, keeping string values only.
		metadata := make(map[string]string)
		for k, v := range providerOptions.Metadata {
			if str, ok := v.(string); ok {
				metadata[k] = str
			}
		}
		params.Metadata = metadata
	}
	if providerOptions.PromptCacheKey != nil {
		params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey)
	}
	if providerOptions.SafetyIdentifier != nil {
		params.SafetyIdentifier = param.NewOpt(*providerOptions.SafetyIdentifier)
	}
	if providerOptions.ServiceTier != nil {
		params.ServiceTier = openai.ChatCompletionNewParamsServiceTier(*providerOptions.ServiceTier)
	}

	if providerOptions.ReasoningEffort != nil {
		switch *providerOptions.ReasoningEffort {
		case ReasoningEffortMinimal:
			params.ReasoningEffort = shared.ReasoningEffortMinimal
		case ReasoningEffortLow:
			params.ReasoningEffort = shared.ReasoningEffortLow
		case ReasoningEffortMedium:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case ReasoningEffortHigh:
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			return nil, fmt.Errorf("unsupported reasoning effort %q", *providerOptions.ReasoningEffort)
		}
	}

	if isReasoningModel(model.Model()) {
		if providerOptions.LogitBias != nil {
			params.LogitBias = nil
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "LogitBias",
				Message: "LogitBias is not supported for reasoning models",
			})
		}
		if providerOptions.LogProbs != nil {
			params.Logprobs = param.Opt[bool]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "Logprobs",
				Message: "Logprobs is not supported for reasoning models",
			})
		}
		if providerOptions.TopLogProbs != nil {
			params.TopLogprobs = param.Opt[int64]{}
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "TopLogprobs",
				Message: "TopLogprobs is not supported for reasoning models",
			})
		}
	}

	// Handle service tier validation
	if providerOptions.ServiceTier != nil {
		serviceTier := *providerOptions.ServiceTier
		if serviceTier == "flex" && !supportsFlexProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "flex processing is only available for o3, o4-mini, and gpt-5 models",
			})
		} else if serviceTier == "priority" && !supportsPriorityProcessing(model.Model()) {
			params.ServiceTier = ""
			warnings = append(warnings, fantasy.CallWarning{
				Type:    fantasy.CallWarningTypeUnsupportedSetting,
				Setting: "ServiceTier",
				Details: "priority processing is only available for supported models (gpt-4, gpt-5, gpt-5-mini, o3, o4-mini) and requires Enterprise access; gpt-5-nano is not supported",
			})
		}
	}
	return warnings, nil
}
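
// A hedged sketch of supplying per-call provider options consumed by
// DefaultPrepareCallFunc. The field names and pointer shapes follow the
// accesses above; the map type of fantasy.Call.ProviderOptions is assumed
// from the lookup at the top of the function.
//
//	logprobs, topLogprobs := true, int64(5)
//	call := fantasy.Call{
//		ProviderOptions: map[string]fantasy.ProviderOptionsData{ // type assumed
//			Name: &ProviderOptions{
//				LogProbs:    &logprobs,
//				TopLogProbs: &topLogprobs,
//			},
//		},
//	}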

// DefaultMapFinishReasonFunc is the default implementation for mapping finish reasons.
func DefaultMapFinishReasonFunc(finishReason string) fantasy.FinishReason {
	switch finishReason {
	case "stop":
		return fantasy.FinishReasonStop
	case "length":
		return fantasy.FinishReasonLength
	case "content_filter":
		return fantasy.FinishReasonContentFilter
	case "function_call", "tool_calls":
		return fantasy.FinishReasonToolCalls
	default:
		return fantasy.FinishReasonUnknown
	}
}
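
// Both tool-related OpenAI finish reasons collapse to a single fantasy
// reason, and anything unrecognized maps to unknown:
//
//	DefaultMapFinishReasonFunc("tool_calls")    // fantasy.FinishReasonToolCalls
//	DefaultMapFinishReasonFunc("function_call") // fantasy.FinishReasonToolCalls
//	DefaultMapFinishReasonFunc("other")         // fantasy.FinishReasonUnknown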

// DefaultUsageFunc is the default implementation for calculating usage.
func DefaultUsageFunc(response openai.ChatCompletion) (fantasy.Usage, fantasy.ProviderOptionsData) {
	completionTokenDetails := response.Usage.CompletionTokensDetails
	promptTokenDetails := response.Usage.PromptTokensDetails

	// Build provider metadata
	providerMetadata := &ProviderMetadata{}

	// Add logprobs if available
	if len(response.Choices) > 0 && len(response.Choices[0].Logprobs.Content) > 0 {
		providerMetadata.Logprobs = response.Choices[0].Logprobs.Content
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		providerMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		providerMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
	}
	return fantasy.Usage{
		InputTokens:     response.Usage.PromptTokens,
		OutputTokens:    response.Usage.CompletionTokens,
		TotalTokens:     response.Usage.TotalTokens,
		ReasoningTokens: completionTokenDetails.ReasoningTokens,
		CacheReadTokens: promptTokenDetails.CachedTokens,
	}, providerMetadata
}
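
// A hedged usage sketch (resp is assumed to be a completed
// openai.ChatCompletion):
//
//	usage, meta := DefaultUsageFunc(resp)
//	_ = meta // *ProviderMetadata carrying logprobs and prediction-token counts
//	fmt.Printf("in=%d out=%d reasoning=%d cached=%d\n",
//		usage.InputTokens, usage.OutputTokens,
//		usage.ReasoningTokens, usage.CacheReadTokens)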

// DefaultStreamUsageFunc is the default implementation for calculating stream usage.
func DefaultStreamUsageFunc(chunk openai.ChatCompletionChunk, _ map[string]any, metadata fantasy.ProviderMetadata) (fantasy.Usage, fantasy.ProviderMetadata) {
	if chunk.Usage.TotalTokens == 0 {
		return fantasy.Usage{}, nil
	}
	streamProviderMetadata := &ProviderMetadata{}
	if metadata != nil {
		if providerMetadata, ok := metadata[Name]; ok {
			converted, ok := providerMetadata.(*ProviderMetadata)
			if ok {
				streamProviderMetadata = converted
			}
		}
	}
	// Read token details from the usage chunk directly, because the stream
	// accumulator does not aggregate prompt token details.
	completionTokenDetails := chunk.Usage.CompletionTokensDetails
	promptTokenDetails := chunk.Usage.PromptTokensDetails
	usage := fantasy.Usage{
		InputTokens:     chunk.Usage.PromptTokens,
		OutputTokens:    chunk.Usage.CompletionTokens,
		TotalTokens:     chunk.Usage.TotalTokens,
		ReasoningTokens: completionTokenDetails.ReasoningTokens,
		CacheReadTokens: promptTokenDetails.CachedTokens,
	}

	// Add prediction tokens if available
	if completionTokenDetails.AcceptedPredictionTokens > 0 {
		streamProviderMetadata.AcceptedPredictionTokens = completionTokenDetails.AcceptedPredictionTokens
	}
	if completionTokenDetails.RejectedPredictionTokens > 0 {
		streamProviderMetadata.RejectedPredictionTokens = completionTokenDetails.RejectedPredictionTokens
	}

	return usage, fantasy.ProviderMetadata{
		Name: streamProviderMetadata,
	}
}
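
// OpenAI reports usage on a trailing stream chunk; chunks without usage have
// TotalTokens == 0 and are skipped above. A hedged accumulation sketch (the
// chunk source is assumed):
//
//	var usage fantasy.Usage
//	var meta fantasy.ProviderMetadata
//	for _, chunk := range chunks { // hypothetical slice of ChatCompletionChunk
//		if u, m := DefaultStreamUsageFunc(chunk, nil, meta); m != nil {
//			usage, meta = u, m
//		}
//	}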

// DefaultStreamProviderMetadataFunc is the default implementation for handling stream provider metadata.
func DefaultStreamProviderMetadataFunc(choice openai.ChatCompletionChoice, metadata fantasy.ProviderMetadata) fantasy.ProviderMetadata {
	streamProviderMetadata, ok := metadata[Name]
	if !ok {
		streamProviderMetadata = &ProviderMetadata{}
	}
	if converted, ok := streamProviderMetadata.(*ProviderMetadata); ok {
		converted.Logprobs = choice.Logprobs.Content
		metadata[Name] = converted
	}
	return metadata
}
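
// A hedged sketch of merging per-choice logprobs into shared metadata (the
// source of finished choices is assumed):
//
//	meta := fantasy.ProviderMetadata{}
//	for _, choice := range choices { // hypothetical []openai.ChatCompletionChoice
//		meta = DefaultStreamProviderMetadataFunc(choice, meta)
//	}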