language_model_hooks.go

  1package openrouter
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"maps"
  7
  8	"github.com/charmbracelet/fantasy/ai"
  9	openaisdk "github.com/openai/openai-go/v2"
 10	"github.com/openai/openai-go/v2/packages/param"
 11)
 12
// reasoningStartedCtx is the key under which the streaming hooks store, in the
// per-stream context map, a bool telling whether a reasoning block is open.
const reasoningStartedCtx = "reasoning_started"
 14
 15func languagePrepareModelCall(model ai.LanguageModel, params *openaisdk.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
 16	providerOptions := &ProviderOptions{}
 17	if v, ok := call.ProviderOptions[Name]; ok {
 18		providerOptions, ok = v.(*ProviderOptions)
 19		if !ok {
 20			return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
 21		}
 22	}
 23
 24	extraFields := make(map[string]any)
 25
 26	if providerOptions.Provider != nil {
 27		data, err := structToMapJSON(providerOptions.Provider)
 28		if err != nil {
 29			return nil, err
 30		}
 31		extraFields["provider"] = data
 32	}
 33
 34	if providerOptions.Reasoning != nil {
 35		data, err := structToMapJSON(providerOptions.Reasoning)
 36		if err != nil {
 37			return nil, err
 38		}
 39		extraFields["reasoning"] = data
 40	}
 41
 42	if providerOptions.IncludeUsage != nil {
 43		extraFields["usage"] = map[string]any{
 44			"include": *providerOptions.IncludeUsage,
 45		}
 46	} else { // default include usage
 47		extraFields["usage"] = map[string]any{
 48			"include": true,
 49		}
 50	}
 51	if providerOptions.LogitBias != nil {
 52		params.LogitBias = providerOptions.LogitBias
 53	}
 54	if providerOptions.LogProbs != nil {
 55		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
 56	}
 57	if providerOptions.User != nil {
 58		params.User = param.NewOpt(*providerOptions.User)
 59	}
 60	if providerOptions.ParallelToolCalls != nil {
 61		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
 62	}
 63
 64	maps.Copy(extraFields, providerOptions.ExtraBody)
 65	params.SetExtraFields(extraFields)
 66	return nil, nil
 67}
 68
 69func languageModelMapFinishReason(choice openaisdk.ChatCompletionChoice) ai.FinishReason {
 70	finishReason := choice.FinishReason
 71	switch finishReason {
 72	case "stop":
 73		return ai.FinishReasonStop
 74	case "length":
 75		return ai.FinishReasonLength
 76	case "content_filter":
 77		return ai.FinishReasonContentFilter
 78	case "function_call", "tool_calls":
 79		return ai.FinishReasonToolCalls
 80	default:
 81		// for streaming responses the openai accumulator is not working as expected with some provider
 82		// therefore it is sending no finish reason so we need to manually handle it
 83		if len(choice.Message.ToolCalls) > 0 {
 84			return ai.FinishReasonToolCalls
 85		} else if finishReason == "" {
 86			return ai.FinishReasonStop
 87		}
 88		return ai.FinishReasonUnknown
 89	}
 90}
 91
 92func languageModelExtraContent(choice openaisdk.ChatCompletionChoice) []ai.Content {
 93	var content []ai.Content
 94	reasoningData := ReasoningData{}
 95	err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
 96	if err != nil {
 97		return content
 98	}
 99	for _, detail := range reasoningData.ReasoningDetails {
100		switch detail.Type {
101		case "reasoning.text":
102			content = append(content, ai.ReasoningContent{
103				Text: detail.Text,
104			})
105		case "reasoning.summary":
106			content = append(content, ai.ReasoningContent{
107				Text: detail.Summary,
108			})
109		case "reasoning.encrypted":
110			content = append(content, ai.ReasoningContent{
111				Text: "[REDACTED]",
112			})
113		}
114	}
115	return content
116}
117
118func extractReasoningContext(ctx map[string]any) bool {
119	reasoningStarted, ok := ctx[reasoningStartedCtx]
120	if !ok {
121		return false
122	}
123	b, ok := reasoningStarted.(bool)
124	if !ok {
125		return false
126	}
127	return b
128}
129
130func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
131	if len(chunk.Choices) == 0 {
132		return ctx, true
133	}
134
135	reasoningStarted := extractReasoningContext(ctx)
136
137	for inx, choice := range chunk.Choices {
138		reasoningData := ReasoningData{}
139		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
140		if err != nil {
141			yield(ai.StreamPart{
142				Type:  ai.StreamPartTypeError,
143				Error: ai.NewAIError("Unexpected", "error unmarshalling delta", err),
144			})
145			return ctx, false
146		}
147
148		emitEvent := func(reasoningContent string) bool {
149			if !reasoningStarted {
150				shouldContinue := yield(ai.StreamPart{
151					Type: ai.StreamPartTypeReasoningStart,
152					ID:   fmt.Sprintf("%d", inx),
153				})
154				if !shouldContinue {
155					return false
156				}
157			}
158
159			return yield(ai.StreamPart{
160				Type:  ai.StreamPartTypeReasoningDelta,
161				ID:    fmt.Sprintf("%d", inx),
162				Delta: reasoningContent,
163			})
164		}
165		if len(reasoningData.ReasoningDetails) > 0 {
166			for _, detail := range reasoningData.ReasoningDetails {
167				if !reasoningStarted {
168					ctx[reasoningStartedCtx] = true
169				}
170				switch detail.Type {
171				case "reasoning.text":
172					return ctx, emitEvent(detail.Text)
173				case "reasoning.summary":
174					return ctx, emitEvent(detail.Summary)
175				case "reasoning.encrypted":
176					return ctx, emitEvent("[REDACTED]")
177				}
178			}
179		} else if reasoningData.Reasoning != "" {
180			return ctx, emitEvent(reasoningData.Reasoning)
181		}
182		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
183			ctx[reasoningStartedCtx] = false
184			return ctx, yield(ai.StreamPart{
185				Type: ai.StreamPartTypeReasoningEnd,
186				ID:   fmt.Sprintf("%d", inx),
187			})
188		}
189	}
190	return ctx, true
191}
192
193func languageModelUsage(response openaisdk.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
194	if len(response.Choices) == 0 {
195		return ai.Usage{}, nil
196	}
197	openrouterUsage := UsageAccounting{}
198	usage := response.Usage
199
200	_ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
201
202	completionTokenDetails := usage.CompletionTokensDetails
203	promptTokenDetails := usage.PromptTokensDetails
204
205	var provider string
206	if p, ok := response.JSON.ExtraFields["provider"]; ok {
207		provider = p.Raw()
208	}
209
210	// Build provider metadata
211	providerMetadata := &ProviderMetadata{
212		Provider: provider,
213		Usage:    openrouterUsage,
214	}
215
216	return ai.Usage{
217		InputTokens:     usage.PromptTokens,
218		OutputTokens:    usage.CompletionTokens,
219		TotalTokens:     usage.TotalTokens,
220		ReasoningTokens: completionTokenDetails.ReasoningTokens,
221		CacheReadTokens: promptTokenDetails.CachedTokens,
222	}, providerMetadata
223}
224
225func languageModelStreamUsage(chunk openaisdk.ChatCompletionChunk, _ map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
226	usage := chunk.Usage
227	if usage.TotalTokens == 0 {
228		return ai.Usage{}, nil
229	}
230
231	streamProviderMetadata := &ProviderMetadata{}
232	if metadata != nil {
233		if providerMetadata, ok := metadata[Name]; ok {
234			converted, ok := providerMetadata.(*ProviderMetadata)
235			if ok {
236				streamProviderMetadata = converted
237			}
238		}
239	}
240	openrouterUsage := UsageAccounting{}
241	_ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
242	streamProviderMetadata.Usage = openrouterUsage
243
244	if p, ok := chunk.JSON.ExtraFields["provider"]; ok {
245		streamProviderMetadata.Provider = p.Raw()
246	}
247
248	// we do this here because the acc does not add prompt details
249	completionTokenDetails := usage.CompletionTokensDetails
250	promptTokenDetails := usage.PromptTokensDetails
251	aiUsage := ai.Usage{
252		InputTokens:     usage.PromptTokens,
253		OutputTokens:    usage.CompletionTokens,
254		TotalTokens:     usage.TotalTokens,
255		ReasoningTokens: completionTokenDetails.ReasoningTokens,
256		CacheReadTokens: promptTokenDetails.CachedTokens,
257	}
258
259	return aiUsage, ai.ProviderMetadata{
260		Name: streamProviderMetadata,
261	}
262}