1package openrouter
2
3import (
4 "encoding/json"
5 "fmt"
6 "maps"
7
8 "github.com/charmbracelet/fantasy/ai"
9 "github.com/charmbracelet/fantasy/anthropic"
10 openaisdk "github.com/openai/openai-go/v2"
11 "github.com/openai/openai-go/v2/packages/param"
12)
13
// reasoningStartedCtx is the stream-context key that records whether a
// reasoning block has been opened for the current stream.
const reasoningStartedCtx = "reasoning_started"
15
16func languagePrepareModelCall(model ai.LanguageModel, params *openaisdk.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
17 providerOptions := &ProviderOptions{}
18 if v, ok := call.ProviderOptions[Name]; ok {
19 providerOptions, ok = v.(*ProviderOptions)
20 if !ok {
21 return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
22 }
23 }
24
25 extraFields := make(map[string]any)
26
27 if providerOptions.Provider != nil {
28 data, err := structToMapJSON(providerOptions.Provider)
29 if err != nil {
30 return nil, err
31 }
32 extraFields["provider"] = data
33 }
34
35 if providerOptions.Reasoning != nil {
36 data, err := structToMapJSON(providerOptions.Reasoning)
37 if err != nil {
38 return nil, err
39 }
40 extraFields["reasoning"] = data
41 }
42
43 if providerOptions.IncludeUsage != nil {
44 extraFields["usage"] = map[string]any{
45 "include": *providerOptions.IncludeUsage,
46 }
47 } else { // default include usage
48 extraFields["usage"] = map[string]any{
49 "include": true,
50 }
51 }
52 if providerOptions.LogitBias != nil {
53 params.LogitBias = providerOptions.LogitBias
54 }
55 if providerOptions.LogProbs != nil {
56 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
57 }
58 if providerOptions.User != nil {
59 params.User = param.NewOpt(*providerOptions.User)
60 }
61 if providerOptions.ParallelToolCalls != nil {
62 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
63 }
64
65 maps.Copy(extraFields, providerOptions.ExtraBody)
66 params.SetExtraFields(extraFields)
67 return nil, nil
68}
69
70func languageModelMapFinishReason(choice openaisdk.ChatCompletionChoice) ai.FinishReason {
71 finishReason := choice.FinishReason
72 switch finishReason {
73 case "stop":
74 return ai.FinishReasonStop
75 case "length":
76 return ai.FinishReasonLength
77 case "content_filter":
78 return ai.FinishReasonContentFilter
79 case "function_call", "tool_calls":
80 return ai.FinishReasonToolCalls
81 default:
82 // for streaming responses the openai accumulator is not working as expected with some provider
83 // therefore it is sending no finish reason so we need to manually handle it
84 if len(choice.Message.ToolCalls) > 0 {
85 return ai.FinishReasonToolCalls
86 } else if finishReason == "" {
87 return ai.FinishReasonStop
88 }
89 return ai.FinishReasonUnknown
90 }
91}
92
93func languageModelExtraContent(choice openaisdk.ChatCompletionChoice) []ai.Content {
94 var content []ai.Content
95 reasoningData := ReasoningData{}
96 err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
97 if err != nil {
98 return content
99 }
100 for _, detail := range reasoningData.ReasoningDetails {
101
102 var metadata ai.ProviderMetadata
103
104 if detail.Signature != "" {
105 metadata = ai.ProviderMetadata{
106 Name: &ReasoningMetadata{
107 Signature: detail.Signature,
108 },
109 anthropic.Name: &anthropic.ReasoningOptionMetadata{
110 Signature: detail.Signature,
111 },
112 }
113 }
114 switch detail.Type {
115 case "reasoning.text":
116 content = append(content, ai.ReasoningContent{
117 Text: detail.Text,
118 ProviderMetadata: metadata,
119 })
120 case "reasoning.summary":
121 content = append(content, ai.ReasoningContent{
122 Text: detail.Summary,
123 ProviderMetadata: metadata,
124 })
125 case "reasoning.encrypted":
126 content = append(content, ai.ReasoningContent{
127 Text: "[REDACTED]",
128 ProviderMetadata: metadata,
129 })
130 }
131 }
132 return content
133}
134
135func extractReasoningContext(ctx map[string]any) bool {
136 reasoningStarted, ok := ctx[reasoningStartedCtx]
137 if !ok {
138 return false
139 }
140 b, ok := reasoningStarted.(bool)
141 if !ok {
142 return false
143 }
144 return b
145}
146
// languageModelStreamExtra inspects a streamed chunk's deltas for OpenRouter
// reasoning payloads and emits reasoning start/delta/end stream parts via
// yield. It returns the (possibly mutated) per-stream context and a continue
// flag: false means the consumer stopped the stream or an unmarshal error part
// was emitted.
func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
	if len(chunk.Choices) == 0 {
		return ctx, true
	}

	// Whether an earlier chunk already opened a reasoning block.
	reasoningStarted := extractReasoningContext(ctx)

	for inx, choice := range chunk.Choices {
		reasoningData := ReasoningData{}
		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
		if err != nil {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: ai.NewAIError("Unexpected", "error unmarshalling delta", err),
			})
			return ctx, false
		}

		// emitEvent opens the reasoning block first if one is not already
		// open, then forwards a single reasoning delta (attaching provider
		// metadata when a signature is present). It returns false when the
		// consumer asked to stop.
		emitEvent := func(reasoningContent string, signature string) bool {
			if !reasoningStarted {
				shouldContinue := yield(ai.StreamPart{
					Type: ai.StreamPartTypeReasoningStart,
					ID:   fmt.Sprintf("%d", inx),
				})
				if !shouldContinue {
					return false
				}
			}

			var metadata ai.ProviderMetadata

			if signature != "" {
				// Mirror the signature under both the OpenRouter and
				// Anthropic keys so the reasoning can be replayed to either
				// provider.
				metadata = ai.ProviderMetadata{
					Name: &ReasoningMetadata{
						Signature: signature,
					},
					anthropic.Name: &anthropic.ReasoningOptionMetadata{
						Signature: signature,
					},
				}
			}

			return yield(ai.StreamPart{
				Type:             ai.StreamPartTypeReasoningDelta,
				ID:               fmt.Sprintf("%d", inx),
				Delta:            reasoningContent,
				ProviderMetadata: metadata,
			})
		}
		if len(reasoningData.ReasoningDetails) > 0 {
			for _, detail := range reasoningData.ReasoningDetails {
				if !reasoningStarted {
					// NOTE(review): the started flag is set before the type
					// switch, so an unrecognized detail type marks reasoning
					// as started without ever emitting a start part —
					// confirm this is intended.
					ctx[reasoningStartedCtx] = true
				}
				switch detail.Type {
				case "reasoning.text":
					// Only the first recognized detail is forwarded; the
					// function returns immediately after emitting it.
					return ctx, emitEvent(detail.Text, detail.Signature)
				case "reasoning.summary":
					return ctx, emitEvent(detail.Summary, detail.Signature)
				case "reasoning.encrypted":
					// Encrypted reasoning is surfaced as a redacted placeholder.
					return ctx, emitEvent("[REDACTED]", detail.Signature)
				}
			}
		} else if reasoningData.Reasoning != "" {
			// Plain (non-detailed) reasoning text carries no signature.
			return ctx, emitEvent(reasoningData.Reasoning, "")
		}
		// Once regular content or tool calls arrive, close the open reasoning
		// block and clear the context flag.
		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
			ctx[reasoningStartedCtx] = false
			return ctx, yield(ai.StreamPart{
				Type: ai.StreamPartTypeReasoningEnd,
				ID:   fmt.Sprintf("%d", inx),
			})
		}
	}
	return ctx, true
}
223
224func languageModelUsage(response openaisdk.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
225 if len(response.Choices) == 0 {
226 return ai.Usage{}, nil
227 }
228 openrouterUsage := UsageAccounting{}
229 usage := response.Usage
230
231 _ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
232
233 completionTokenDetails := usage.CompletionTokensDetails
234 promptTokenDetails := usage.PromptTokensDetails
235
236 var provider string
237 if p, ok := response.JSON.ExtraFields["provider"]; ok {
238 provider = p.Raw()
239 }
240
241 // Build provider metadata
242 providerMetadata := &ProviderMetadata{
243 Provider: provider,
244 Usage: openrouterUsage,
245 }
246
247 return ai.Usage{
248 InputTokens: usage.PromptTokens,
249 OutputTokens: usage.CompletionTokens,
250 TotalTokens: usage.TotalTokens,
251 ReasoningTokens: completionTokenDetails.ReasoningTokens,
252 CacheReadTokens: promptTokenDetails.CachedTokens,
253 }, providerMetadata
254}
255
256func languageModelStreamUsage(chunk openaisdk.ChatCompletionChunk, _ map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
257 usage := chunk.Usage
258 if usage.TotalTokens == 0 {
259 return ai.Usage{}, nil
260 }
261
262 streamProviderMetadata := &ProviderMetadata{}
263 if metadata != nil {
264 if providerMetadata, ok := metadata[Name]; ok {
265 converted, ok := providerMetadata.(*ProviderMetadata)
266 if ok {
267 streamProviderMetadata = converted
268 }
269 }
270 }
271 openrouterUsage := UsageAccounting{}
272 _ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
273 streamProviderMetadata.Usage = openrouterUsage
274
275 if p, ok := chunk.JSON.ExtraFields["provider"]; ok {
276 streamProviderMetadata.Provider = p.Raw()
277 }
278
279 // we do this here because the acc does not add prompt details
280 completionTokenDetails := usage.CompletionTokensDetails
281 promptTokenDetails := usage.PromptTokensDetails
282 aiUsage := ai.Usage{
283 InputTokens: usage.PromptTokens,
284 OutputTokens: usage.CompletionTokens,
285 TotalTokens: usage.TotalTokens,
286 ReasoningTokens: completionTokenDetails.ReasoningTokens,
287 CacheReadTokens: promptTokenDetails.CachedTokens,
288 }
289
290 return aiUsage, ai.ProviderMetadata{
291 Name: streamProviderMetadata,
292 }
293}