language_model_hooks.go

  1package openrouter
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"maps"
  7
  8	"github.com/charmbracelet/fantasy/ai"
  9	"github.com/charmbracelet/fantasy/anthropic"
 10	openaisdk "github.com/openai/openai-go/v2"
 11	"github.com/openai/openai-go/v2/packages/param"
 12)
 13
 14const reasoningStartedCtx = "reasoning_started"
 15
 16func languagePrepareModelCall(model ai.LanguageModel, params *openaisdk.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
 17	providerOptions := &ProviderOptions{}
 18	if v, ok := call.ProviderOptions[Name]; ok {
 19		providerOptions, ok = v.(*ProviderOptions)
 20		if !ok {
 21			return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
 22		}
 23	}
 24
 25	extraFields := make(map[string]any)
 26
 27	if providerOptions.Provider != nil {
 28		data, err := structToMapJSON(providerOptions.Provider)
 29		if err != nil {
 30			return nil, err
 31		}
 32		extraFields["provider"] = data
 33	}
 34
 35	if providerOptions.Reasoning != nil {
 36		data, err := structToMapJSON(providerOptions.Reasoning)
 37		if err != nil {
 38			return nil, err
 39		}
 40		extraFields["reasoning"] = data
 41	}
 42
 43	if providerOptions.IncludeUsage != nil {
 44		extraFields["usage"] = map[string]any{
 45			"include": *providerOptions.IncludeUsage,
 46		}
 47	} else { // default include usage
 48		extraFields["usage"] = map[string]any{
 49			"include": true,
 50		}
 51	}
 52	if providerOptions.LogitBias != nil {
 53		params.LogitBias = providerOptions.LogitBias
 54	}
 55	if providerOptions.LogProbs != nil {
 56		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
 57	}
 58	if providerOptions.User != nil {
 59		params.User = param.NewOpt(*providerOptions.User)
 60	}
 61	if providerOptions.ParallelToolCalls != nil {
 62		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
 63	}
 64
 65	maps.Copy(extraFields, providerOptions.ExtraBody)
 66	params.SetExtraFields(extraFields)
 67	return nil, nil
 68}
 69
 70func languageModelMapFinishReason(choice openaisdk.ChatCompletionChoice) ai.FinishReason {
 71	finishReason := choice.FinishReason
 72	switch finishReason {
 73	case "stop":
 74		return ai.FinishReasonStop
 75	case "length":
 76		return ai.FinishReasonLength
 77	case "content_filter":
 78		return ai.FinishReasonContentFilter
 79	case "function_call", "tool_calls":
 80		return ai.FinishReasonToolCalls
 81	default:
 82		// for streaming responses the openai accumulator is not working as expected with some provider
 83		// therefore it is sending no finish reason so we need to manually handle it
 84		if len(choice.Message.ToolCalls) > 0 {
 85			return ai.FinishReasonToolCalls
 86		} else if finishReason == "" {
 87			return ai.FinishReasonStop
 88		}
 89		return ai.FinishReasonUnknown
 90	}
 91}
 92
 93func languageModelExtraContent(choice openaisdk.ChatCompletionChoice) []ai.Content {
 94	var content []ai.Content
 95	reasoningData := ReasoningData{}
 96	err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
 97	if err != nil {
 98		return content
 99	}
100	for _, detail := range reasoningData.ReasoningDetails {
101
102		var metadata ai.ProviderMetadata
103
104		if detail.Signature != "" {
105			metadata = ai.ProviderMetadata{
106				Name: &ReasoningMetadata{
107					Signature: detail.Signature,
108				},
109				anthropic.Name: &anthropic.ReasoningOptionMetadata{
110					Signature: detail.Signature,
111				},
112			}
113		}
114		switch detail.Type {
115		case "reasoning.text":
116			content = append(content, ai.ReasoningContent{
117				Text:             detail.Text,
118				ProviderMetadata: metadata,
119			})
120		case "reasoning.summary":
121			content = append(content, ai.ReasoningContent{
122				Text:             detail.Summary,
123				ProviderMetadata: metadata,
124			})
125		case "reasoning.encrypted":
126			content = append(content, ai.ReasoningContent{
127				Text:             "[REDACTED]",
128				ProviderMetadata: metadata,
129			})
130		}
131	}
132	return content
133}
134
135func extractReasoningContext(ctx map[string]any) bool {
136	reasoningStarted, ok := ctx[reasoningStartedCtx]
137	if !ok {
138		return false
139	}
140	b, ok := reasoningStarted.(bool)
141	if !ok {
142		return false
143	}
144	return b
145}
146
147func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
148	if len(chunk.Choices) == 0 {
149		return ctx, true
150	}
151
152	reasoningStarted := extractReasoningContext(ctx)
153
154	for inx, choice := range chunk.Choices {
155		reasoningData := ReasoningData{}
156		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
157		if err != nil {
158			yield(ai.StreamPart{
159				Type:  ai.StreamPartTypeError,
160				Error: ai.NewAIError("Unexpected", "error unmarshalling delta", err),
161			})
162			return ctx, false
163		}
164
165		emitEvent := func(reasoningContent string, signature string) bool {
166			if !reasoningStarted {
167				shouldContinue := yield(ai.StreamPart{
168					Type: ai.StreamPartTypeReasoningStart,
169					ID:   fmt.Sprintf("%d", inx),
170				})
171				if !shouldContinue {
172					return false
173				}
174			}
175
176			var metadata ai.ProviderMetadata
177
178			if signature != "" {
179				metadata = ai.ProviderMetadata{
180					Name: &ReasoningMetadata{
181						Signature: signature,
182					},
183					anthropic.Name: &anthropic.ReasoningOptionMetadata{
184						Signature: signature,
185					},
186				}
187			}
188
189			return yield(ai.StreamPart{
190				Type:             ai.StreamPartTypeReasoningDelta,
191				ID:               fmt.Sprintf("%d", inx),
192				Delta:            reasoningContent,
193				ProviderMetadata: metadata,
194			})
195		}
196		if len(reasoningData.ReasoningDetails) > 0 {
197			for _, detail := range reasoningData.ReasoningDetails {
198				if !reasoningStarted {
199					ctx[reasoningStartedCtx] = true
200				}
201				switch detail.Type {
202				case "reasoning.text":
203					return ctx, emitEvent(detail.Text, detail.Signature)
204				case "reasoning.summary":
205					return ctx, emitEvent(detail.Summary, detail.Signature)
206				case "reasoning.encrypted":
207					return ctx, emitEvent("[REDACTED]", detail.Signature)
208				}
209			}
210		} else if reasoningData.Reasoning != "" {
211			return ctx, emitEvent(reasoningData.Reasoning, "")
212		}
213		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
214			ctx[reasoningStartedCtx] = false
215			return ctx, yield(ai.StreamPart{
216				Type: ai.StreamPartTypeReasoningEnd,
217				ID:   fmt.Sprintf("%d", inx),
218			})
219		}
220	}
221	return ctx, true
222}
223
224func languageModelUsage(response openaisdk.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
225	if len(response.Choices) == 0 {
226		return ai.Usage{}, nil
227	}
228	openrouterUsage := UsageAccounting{}
229	usage := response.Usage
230
231	_ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
232
233	completionTokenDetails := usage.CompletionTokensDetails
234	promptTokenDetails := usage.PromptTokensDetails
235
236	var provider string
237	if p, ok := response.JSON.ExtraFields["provider"]; ok {
238		provider = p.Raw()
239	}
240
241	// Build provider metadata
242	providerMetadata := &ProviderMetadata{
243		Provider: provider,
244		Usage:    openrouterUsage,
245	}
246
247	return ai.Usage{
248		InputTokens:     usage.PromptTokens,
249		OutputTokens:    usage.CompletionTokens,
250		TotalTokens:     usage.TotalTokens,
251		ReasoningTokens: completionTokenDetails.ReasoningTokens,
252		CacheReadTokens: promptTokenDetails.CachedTokens,
253	}, providerMetadata
254}
255
256func languageModelStreamUsage(chunk openaisdk.ChatCompletionChunk, _ map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
257	usage := chunk.Usage
258	if usage.TotalTokens == 0 {
259		return ai.Usage{}, nil
260	}
261
262	streamProviderMetadata := &ProviderMetadata{}
263	if metadata != nil {
264		if providerMetadata, ok := metadata[Name]; ok {
265			converted, ok := providerMetadata.(*ProviderMetadata)
266			if ok {
267				streamProviderMetadata = converted
268			}
269		}
270	}
271	openrouterUsage := UsageAccounting{}
272	_ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
273	streamProviderMetadata.Usage = openrouterUsage
274
275	if p, ok := chunk.JSON.ExtraFields["provider"]; ok {
276		streamProviderMetadata.Provider = p.Raw()
277	}
278
279	// we do this here because the acc does not add prompt details
280	completionTokenDetails := usage.CompletionTokensDetails
281	promptTokenDetails := usage.PromptTokensDetails
282	aiUsage := ai.Usage{
283		InputTokens:     usage.PromptTokens,
284		OutputTokens:    usage.CompletionTokens,
285		TotalTokens:     usage.TotalTokens,
286		ReasoningTokens: completionTokenDetails.ReasoningTokens,
287		CacheReadTokens: promptTokenDetails.CachedTokens,
288	}
289
290	return aiUsage, ai.ProviderMetadata{
291		Name: streamProviderMetadata,
292	}
293}