1package openrouter
2
3import (
4 "encoding/json"
5 "fmt"
6 "maps"
7
8 "github.com/charmbracelet/fantasy/ai"
9 "github.com/charmbracelet/fantasy/anthropic"
10 openaisdk "github.com/openai/openai-go/v2"
11 "github.com/openai/openai-go/v2/packages/param"
12)
13
14const reasoningStartedCtx = "reasoning_started"
15
16func languagePrepareModelCall(model ai.LanguageModel, params *openaisdk.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
17 providerOptions := &ProviderOptions{}
18 if v, ok := call.ProviderOptions[Name]; ok {
19 providerOptions, ok = v.(*ProviderOptions)
20 if !ok {
21 return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
22 }
23 }
24
25 extraFields := make(map[string]any)
26
27 if providerOptions.Provider != nil {
28 data, err := structToMapJSON(providerOptions.Provider)
29 if err != nil {
30 return nil, err
31 }
32 extraFields["provider"] = data
33 }
34
35 if providerOptions.Reasoning != nil {
36 data, err := structToMapJSON(providerOptions.Reasoning)
37 if err != nil {
38 return nil, err
39 }
40 extraFields["reasoning"] = data
41 }
42
43 if providerOptions.IncludeUsage != nil {
44 extraFields["usage"] = map[string]any{
45 "include": *providerOptions.IncludeUsage,
46 }
47 } else { // default include usage
48 extraFields["usage"] = map[string]any{
49 "include": true,
50 }
51 }
52 if providerOptions.LogitBias != nil {
53 params.LogitBias = providerOptions.LogitBias
54 }
55 if providerOptions.LogProbs != nil {
56 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
57 }
58 if providerOptions.User != nil {
59 params.User = param.NewOpt(*providerOptions.User)
60 }
61 if providerOptions.ParallelToolCalls != nil {
62 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
63 }
64
65 maps.Copy(extraFields, providerOptions.ExtraBody)
66 params.SetExtraFields(extraFields)
67 return nil, nil
68}
69
70func languageModelMapFinishReason(choice openaisdk.ChatCompletionChoice) ai.FinishReason {
71 finishReason := choice.FinishReason
72 switch finishReason {
73 case "stop":
74 return ai.FinishReasonStop
75 case "length":
76 return ai.FinishReasonLength
77 case "content_filter":
78 return ai.FinishReasonContentFilter
79 case "function_call", "tool_calls":
80 return ai.FinishReasonToolCalls
81 default:
82 // for streaming responses the openai accumulator is not working as expected with some provider
83 // therefore it is sending no finish reason so we need to manually handle it
84 if len(choice.Message.ToolCalls) > 0 {
85 return ai.FinishReasonToolCalls
86 } else if finishReason == "" {
87 return ai.FinishReasonStop
88 }
89 return ai.FinishReasonUnknown
90 }
91}
92
93func languageModelExtraContent(choice openaisdk.ChatCompletionChoice) []ai.Content {
94 var content []ai.Content
95 reasoningData := ReasoningData{}
96 err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
97 if err != nil {
98 return content
99 }
100 for _, detail := range reasoningData.ReasoningDetails {
101 var metadata ai.ProviderMetadata
102
103 if detail.Signature != "" {
104 metadata = ai.ProviderMetadata{
105 Name: &ReasoningMetadata{
106 Signature: detail.Signature,
107 },
108 anthropic.Name: &anthropic.ReasoningOptionMetadata{
109 Signature: detail.Signature,
110 },
111 }
112 }
113 switch detail.Type {
114 case "reasoning.text":
115 content = append(content, ai.ReasoningContent{
116 Text: detail.Text,
117 ProviderMetadata: metadata,
118 })
119 case "reasoning.summary":
120 content = append(content, ai.ReasoningContent{
121 Text: detail.Summary,
122 ProviderMetadata: metadata,
123 })
124 case "reasoning.encrypted":
125 content = append(content, ai.ReasoningContent{
126 Text: "[REDACTED]",
127 ProviderMetadata: metadata,
128 })
129 }
130 }
131 return content
132}
133
134func extractReasoningContext(ctx map[string]any) bool {
135 reasoningStarted, ok := ctx[reasoningStartedCtx]
136 if !ok {
137 return false
138 }
139 b, ok := reasoningStarted.(bool)
140 if !ok {
141 return false
142 }
143 return b
144}
145
// languageModelStreamExtra scans a streaming chunk for OpenRouter reasoning
// payloads and yields the matching reasoning stream parts (start, delta, end).
//
// ctx is per-stream scratch state shared across chunks: the
// reasoningStartedCtx key records whether a reasoning-start part has already
// been emitted. It returns the (possibly updated) ctx plus a bool with
// yield's should-continue semantics — false means the consumer stopped the
// stream.
func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
	if len(chunk.Choices) == 0 {
		return ctx, true
	}

	reasoningStarted := extractReasoningContext(ctx)

	for inx, choice := range chunk.Choices {
		// Reasoning fields are OpenRouter extensions not modeled by the
		// OpenAI SDK types, so re-parse them from the delta's raw JSON.
		reasoningData := ReasoningData{}
		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
		if err != nil {
			yield(ai.StreamPart{
				Type:  ai.StreamPartTypeError,
				Error: ai.NewAIError("Unexpected", "error unmarshalling delta", err),
			})
			return ctx, false
		}

		// emitEvent yields a reasoning-start part (only if none has been
		// emitted yet for this stream) followed by a reasoning-delta part
		// carrying reasoningContent and, when signature is non-empty,
		// provider metadata for both openrouter and anthropic. It returns
		// false as soon as the consumer asks to stop.
		emitEvent := func(reasoningContent string, signature string) bool {
			if !reasoningStarted {
				shouldContinue := yield(ai.StreamPart{
					Type: ai.StreamPartTypeReasoningStart,
					ID:   fmt.Sprintf("%d", inx),
				})
				if !shouldContinue {
					return false
				}
			}

			var metadata ai.ProviderMetadata

			if signature != "" {
				metadata = ai.ProviderMetadata{
					Name: &ReasoningMetadata{
						Signature: signature,
					},
					anthropic.Name: &anthropic.ReasoningOptionMetadata{
						Signature: signature,
					},
				}
			}

			return yield(ai.StreamPart{
				Type:             ai.StreamPartTypeReasoningDelta,
				ID:               fmt.Sprintf("%d", inx),
				Delta:            reasoningContent,
				ProviderMetadata: metadata,
			})
		}
		if len(reasoningData.ReasoningDetails) > 0 {
			// NOTE(review): each matching case returns immediately, so at
			// most the first recognized detail per chunk is emitted —
			// presumably chunks carry a single detail; confirm upstream.
			for _, detail := range reasoningData.ReasoningDetails {
				if !reasoningStarted {
					// Mark the start as recorded before yielding so later
					// chunks do not re-emit the reasoning-start part.
					ctx[reasoningStartedCtx] = true
				}
				switch detail.Type {
				case "reasoning.text":
					return ctx, emitEvent(detail.Text, detail.Signature)
				case "reasoning.summary":
					return ctx, emitEvent(detail.Summary, detail.Signature)
				case "reasoning.encrypted":
					// Encrypted reasoning is not exposed; send a placeholder.
					return ctx, emitEvent("[REDACTED]", detail.Signature)
				}
			}
		} else if reasoningData.Reasoning != "" {
			// Plain (non-detailed) reasoning text, without a signature.
			return ctx, emitEvent(reasoningData.Reasoning, "")
		}
		// Once regular content or tool calls arrive the reasoning block is
		// over: clear the flag and emit a reasoning-end part.
		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
			ctx[reasoningStartedCtx] = false
			return ctx, yield(ai.StreamPart{
				Type: ai.StreamPartTypeReasoningEnd,
				ID:   fmt.Sprintf("%d", inx),
			})
		}
	}
	return ctx, true
}
222
223func languageModelUsage(response openaisdk.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
224 if len(response.Choices) == 0 {
225 return ai.Usage{}, nil
226 }
227 openrouterUsage := UsageAccounting{}
228 usage := response.Usage
229
230 _ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
231
232 completionTokenDetails := usage.CompletionTokensDetails
233 promptTokenDetails := usage.PromptTokensDetails
234
235 var provider string
236 if p, ok := response.JSON.ExtraFields["provider"]; ok {
237 provider = p.Raw()
238 }
239
240 // Build provider metadata
241 providerMetadata := &ProviderMetadata{
242 Provider: provider,
243 Usage: openrouterUsage,
244 }
245
246 return ai.Usage{
247 InputTokens: usage.PromptTokens,
248 OutputTokens: usage.CompletionTokens,
249 TotalTokens: usage.TotalTokens,
250 ReasoningTokens: completionTokenDetails.ReasoningTokens,
251 CacheReadTokens: promptTokenDetails.CachedTokens,
252 }, providerMetadata
253}
254
255func languageModelStreamUsage(chunk openaisdk.ChatCompletionChunk, _ map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
256 usage := chunk.Usage
257 if usage.TotalTokens == 0 {
258 return ai.Usage{}, nil
259 }
260
261 streamProviderMetadata := &ProviderMetadata{}
262 if metadata != nil {
263 if providerMetadata, ok := metadata[Name]; ok {
264 converted, ok := providerMetadata.(*ProviderMetadata)
265 if ok {
266 streamProviderMetadata = converted
267 }
268 }
269 }
270 openrouterUsage := UsageAccounting{}
271 _ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
272 streamProviderMetadata.Usage = openrouterUsage
273
274 if p, ok := chunk.JSON.ExtraFields["provider"]; ok {
275 streamProviderMetadata.Provider = p.Raw()
276 }
277
278 // we do this here because the acc does not add prompt details
279 completionTokenDetails := usage.CompletionTokensDetails
280 promptTokenDetails := usage.PromptTokensDetails
281 aiUsage := ai.Usage{
282 InputTokens: usage.PromptTokens,
283 OutputTokens: usage.CompletionTokens,
284 TotalTokens: usage.TotalTokens,
285 ReasoningTokens: completionTokenDetails.ReasoningTokens,
286 CacheReadTokens: promptTokenDetails.CachedTokens,
287 }
288
289 return aiUsage, ai.ProviderMetadata{
290 Name: streamProviderMetadata,
291 }
292}