1package openrouter
2
3import (
4 "encoding/json"
5 "fmt"
6 "maps"
7
8 "github.com/charmbracelet/fantasy/ai"
9 openaisdk "github.com/openai/openai-go/v2"
10 "github.com/openai/openai-go/v2/packages/param"
11)
12
// reasoningStartedCtx is the key used in the per-stream context map to track
// whether reasoning output is currently in progress (see
// languageModelStreamExtra / extractReasoningContext).
const reasoningStartedCtx = "reasoning_started"
14
15func languagePrepareModelCall(model ai.LanguageModel, params *openaisdk.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
16 providerOptions := &ProviderOptions{}
17 if v, ok := call.ProviderOptions[Name]; ok {
18 providerOptions, ok = v.(*ProviderOptions)
19 if !ok {
20 return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
21 }
22 }
23
24 extraFields := make(map[string]any)
25
26 if providerOptions.Provider != nil {
27 data, err := structToMapJSON(providerOptions.Provider)
28 if err != nil {
29 return nil, err
30 }
31 extraFields["provider"] = data
32 }
33
34 if providerOptions.Reasoning != nil {
35 data, err := structToMapJSON(providerOptions.Reasoning)
36 if err != nil {
37 return nil, err
38 }
39 extraFields["reasoning"] = data
40 }
41
42 if providerOptions.IncludeUsage != nil {
43 extraFields["usage"] = map[string]any{
44 "include": *providerOptions.IncludeUsage,
45 }
46 } else { // default include usage
47 extraFields["usage"] = map[string]any{
48 "include": true,
49 }
50 }
51 if providerOptions.LogitBias != nil {
52 params.LogitBias = providerOptions.LogitBias
53 }
54 if providerOptions.LogProbs != nil {
55 params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
56 }
57 if providerOptions.User != nil {
58 params.User = param.NewOpt(*providerOptions.User)
59 }
60 if providerOptions.ParallelToolCalls != nil {
61 params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
62 }
63
64 maps.Copy(extraFields, providerOptions.ExtraBody)
65 params.SetExtraFields(extraFields)
66 return nil, nil
67}
68
69func languageModelMapFinishReason(choice openaisdk.ChatCompletionChoice) ai.FinishReason {
70 finishReason := choice.FinishReason
71 switch finishReason {
72 case "stop":
73 return ai.FinishReasonStop
74 case "length":
75 return ai.FinishReasonLength
76 case "content_filter":
77 return ai.FinishReasonContentFilter
78 case "function_call", "tool_calls":
79 return ai.FinishReasonToolCalls
80 default:
81 // for streaming responses the openai accumulator is not working as expected with some provider
82 // therefore it is sending no finish reason so we need to manually handle it
83 if len(choice.Message.ToolCalls) > 0 {
84 return ai.FinishReasonToolCalls
85 } else if finishReason == "" {
86 return ai.FinishReasonStop
87 }
88 return ai.FinishReasonUnknown
89 }
90}
91
92func languageModelExtraContent(choice openaisdk.ChatCompletionChoice) []ai.Content {
93 var content []ai.Content
94 reasoningData := ReasoningData{}
95 err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
96 if err != nil {
97 return content
98 }
99 for _, detail := range reasoningData.ReasoningDetails {
100 switch detail.Type {
101 case "reasoning.text":
102 content = append(content, ai.ReasoningContent{
103 Text: detail.Text,
104 })
105 case "reasoning.summary":
106 content = append(content, ai.ReasoningContent{
107 Text: detail.Summary,
108 })
109 case "reasoning.encrypted":
110 content = append(content, ai.ReasoningContent{
111 Text: "[REDACTED]",
112 })
113 }
114 }
115 return content
116}
117
118func extractReasoningContext(ctx map[string]any) bool {
119 reasoningStarted, ok := ctx[reasoningStartedCtx]
120 if !ok {
121 return false
122 }
123 b, ok := reasoningStarted.(bool)
124 if !ok {
125 return false
126 }
127 return b
128}
129
// languageModelStreamExtra scans a streaming chunk's choice deltas for
// OpenRouter reasoning payloads and yields the corresponding reasoning
// stream parts (start, delta, end).
//
// ctx is per-stream scratch state shared across chunks (keyed by
// reasoningStartedCtx). The returned map is that state after this chunk; the
// returned bool is false when streaming should stop (yield declined, or a
// delta failed to parse).
func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
	// Nothing to inspect in this chunk; keep streaming.
	if len(chunk.Choices) == 0 {
		return ctx, true
	}

	reasoningStarted := extractReasoningContext(ctx)

	for inx, choice := range chunk.Choices {
		reasoningData := ReasoningData{}
		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
		if err != nil {
			yield(ai.StreamPart{
				Type: ai.StreamPartTypeError,
				Error: ai.NewAIError("Unexpected", "error unmarshalling delta", err),
			})
			return ctx, false
		}

		// emitEvent sends a reasoning-start part first if no reasoning run is
		// in progress, then a reasoning delta carrying reasoningContent. It
		// returns yield's continue signal. The choice index serves as part ID.
		emitEvent := func(reasoningContent string) bool {
			if !reasoningStarted {
				shouldContinue := yield(ai.StreamPart{
					Type: ai.StreamPartTypeReasoningStart,
					ID: fmt.Sprintf("%d", inx),
				})
				if !shouldContinue {
					return false
				}
			}

			return yield(ai.StreamPart{
				Type: ai.StreamPartTypeReasoningDelta,
				ID: fmt.Sprintf("%d", inx),
				Delta: reasoningContent,
			})
		}
		if len(reasoningData.ReasoningDetails) > 0 {
			for _, detail := range reasoningData.ReasoningDetails {
				// Mark the stream as inside a reasoning run before emitting.
				// NOTE(review): this flag is set even when detail.Type matches
				// none of the cases below, i.e. ctx can say "started" without a
				// start part having been emitted — confirm this is intended.
				if !reasoningStarted {
					ctx[reasoningStartedCtx] = true
				}
				// Only the first recognized detail triggers output: each case
				// returns immediately, so later details in the same delta are
				// not emitted here.
				switch detail.Type {
				case "reasoning.text":
					return ctx, emitEvent(detail.Text)
				case "reasoning.summary":
					return ctx, emitEvent(detail.Summary)
				case "reasoning.encrypted":
					// Encrypted payloads are redacted, never surfaced.
					return ctx, emitEvent("[REDACTED]")
				}
			}
		} else if reasoningData.Reasoning != "" {
			// Plain (non-detailed) reasoning text delta.
			return ctx, emitEvent(reasoningData.Reasoning)
		}
		// Regular content or tool calls arriving while a reasoning run is
		// open means the reasoning phase ended: close it with a reasoning-end
		// part and clear the flag.
		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
			ctx[reasoningStartedCtx] = false
			return ctx, yield(ai.StreamPart{
				Type: ai.StreamPartTypeReasoningEnd,
				ID: fmt.Sprintf("%d", inx),
			})
		}
	}
	return ctx, true
}
192
193func languageModelUsage(response openaisdk.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
194 if len(response.Choices) == 0 {
195 return ai.Usage{}, nil
196 }
197 openrouterUsage := UsageAccounting{}
198 usage := response.Usage
199
200 _ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
201
202 completionTokenDetails := usage.CompletionTokensDetails
203 promptTokenDetails := usage.PromptTokensDetails
204
205 var provider string
206 if p, ok := response.JSON.ExtraFields["provider"]; ok {
207 provider = p.Raw()
208 }
209
210 // Build provider metadata
211 providerMetadata := &ProviderMetadata{
212 Provider: provider,
213 Usage: openrouterUsage,
214 }
215
216 return ai.Usage{
217 InputTokens: usage.PromptTokens,
218 OutputTokens: usage.CompletionTokens,
219 TotalTokens: usage.TotalTokens,
220 ReasoningTokens: completionTokenDetails.ReasoningTokens,
221 CacheReadTokens: promptTokenDetails.CachedTokens,
222 }, providerMetadata
223}
224
225func languageModelStreamUsage(chunk openaisdk.ChatCompletionChunk, _ map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
226 usage := chunk.Usage
227 if usage.TotalTokens == 0 {
228 return ai.Usage{}, nil
229 }
230
231 streamProviderMetadata := &ProviderMetadata{}
232 if metadata != nil {
233 if providerMetadata, ok := metadata[Name]; ok {
234 converted, ok := providerMetadata.(*ProviderMetadata)
235 if ok {
236 streamProviderMetadata = converted
237 }
238 }
239 }
240 openrouterUsage := UsageAccounting{}
241 _ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
242 streamProviderMetadata.Usage = openrouterUsage
243
244 if p, ok := chunk.JSON.ExtraFields["provider"]; ok {
245 streamProviderMetadata.Provider = p.Raw()
246 }
247
248 // we do this here because the acc does not add prompt details
249 completionTokenDetails := usage.CompletionTokensDetails
250 promptTokenDetails := usage.PromptTokensDetails
251 aiUsage := ai.Usage{
252 InputTokens: usage.PromptTokens,
253 OutputTokens: usage.CompletionTokens,
254 TotalTokens: usage.TotalTokens,
255 ReasoningTokens: completionTokenDetails.ReasoningTokens,
256 CacheReadTokens: promptTokenDetails.CachedTokens,
257 }
258
259 return aiUsage, ai.ProviderMetadata{
260 Name: streamProviderMetadata,
261 }
262}