language_model_hooks.go

  1package openrouter
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"maps"
  7
  8	"github.com/charmbracelet/fantasy/ai"
  9	"github.com/charmbracelet/fantasy/anthropic"
 10	openaisdk "github.com/openai/openai-go/v2"
 11	"github.com/openai/openai-go/v2/packages/param"
 12)
 13
 14const reasoningStartedCtx = "reasoning_started"
 15
 16func languagePrepareModelCall(model ai.LanguageModel, params *openaisdk.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
 17	providerOptions := &ProviderOptions{}
 18	if v, ok := call.ProviderOptions[Name]; ok {
 19		providerOptions, ok = v.(*ProviderOptions)
 20		if !ok {
 21			return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
 22		}
 23	}
 24
 25	extraFields := make(map[string]any)
 26
 27	if providerOptions.Provider != nil {
 28		data, err := structToMapJSON(providerOptions.Provider)
 29		if err != nil {
 30			return nil, err
 31		}
 32		extraFields["provider"] = data
 33	}
 34
 35	if providerOptions.Reasoning != nil {
 36		data, err := structToMapJSON(providerOptions.Reasoning)
 37		if err != nil {
 38			return nil, err
 39		}
 40		extraFields["reasoning"] = data
 41	}
 42
 43	if providerOptions.IncludeUsage != nil {
 44		extraFields["usage"] = map[string]any{
 45			"include": *providerOptions.IncludeUsage,
 46		}
 47	} else { // default include usage
 48		extraFields["usage"] = map[string]any{
 49			"include": true,
 50		}
 51	}
 52	if providerOptions.LogitBias != nil {
 53		params.LogitBias = providerOptions.LogitBias
 54	}
 55	if providerOptions.LogProbs != nil {
 56		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
 57	}
 58	if providerOptions.User != nil {
 59		params.User = param.NewOpt(*providerOptions.User)
 60	}
 61	if providerOptions.ParallelToolCalls != nil {
 62		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
 63	}
 64
 65	maps.Copy(extraFields, providerOptions.ExtraBody)
 66	params.SetExtraFields(extraFields)
 67	return nil, nil
 68}
 69
 70func languageModelMapFinishReason(choice openaisdk.ChatCompletionChoice) ai.FinishReason {
 71	finishReason := choice.FinishReason
 72	switch finishReason {
 73	case "stop":
 74		return ai.FinishReasonStop
 75	case "length":
 76		return ai.FinishReasonLength
 77	case "content_filter":
 78		return ai.FinishReasonContentFilter
 79	case "function_call", "tool_calls":
 80		return ai.FinishReasonToolCalls
 81	default:
 82		// for streaming responses the openai accumulator is not working as expected with some provider
 83		// therefore it is sending no finish reason so we need to manually handle it
 84		if len(choice.Message.ToolCalls) > 0 {
 85			return ai.FinishReasonToolCalls
 86		} else if finishReason == "" {
 87			return ai.FinishReasonStop
 88		}
 89		return ai.FinishReasonUnknown
 90	}
 91}
 92
 93func languageModelExtraContent(choice openaisdk.ChatCompletionChoice) []ai.Content {
 94	var content []ai.Content
 95	reasoningData := ReasoningData{}
 96	err := json.Unmarshal([]byte(choice.Message.RawJSON()), &reasoningData)
 97	if err != nil {
 98		return content
 99	}
100	for _, detail := range reasoningData.ReasoningDetails {
101		var metadata ai.ProviderMetadata
102
103		if detail.Signature != "" {
104			metadata = ai.ProviderMetadata{
105				Name: &ReasoningMetadata{
106					Signature: detail.Signature,
107				},
108				anthropic.Name: &anthropic.ReasoningOptionMetadata{
109					Signature: detail.Signature,
110				},
111			}
112		}
113		switch detail.Type {
114		case "reasoning.text":
115			content = append(content, ai.ReasoningContent{
116				Text:             detail.Text,
117				ProviderMetadata: metadata,
118			})
119		case "reasoning.summary":
120			content = append(content, ai.ReasoningContent{
121				Text:             detail.Summary,
122				ProviderMetadata: metadata,
123			})
124		case "reasoning.encrypted":
125			content = append(content, ai.ReasoningContent{
126				Text:             "[REDACTED]",
127				ProviderMetadata: metadata,
128			})
129		}
130	}
131	return content
132}
133
134func extractReasoningContext(ctx map[string]any) bool {
135	reasoningStarted, ok := ctx[reasoningStartedCtx]
136	if !ok {
137		return false
138	}
139	b, ok := reasoningStarted.(bool)
140	if !ok {
141		return false
142	}
143	return b
144}
145
146func languageModelStreamExtra(chunk openaisdk.ChatCompletionChunk, yield func(ai.StreamPart) bool, ctx map[string]any) (map[string]any, bool) {
147	if len(chunk.Choices) == 0 {
148		return ctx, true
149	}
150
151	reasoningStarted := extractReasoningContext(ctx)
152
153	for inx, choice := range chunk.Choices {
154		reasoningData := ReasoningData{}
155		err := json.Unmarshal([]byte(choice.Delta.RawJSON()), &reasoningData)
156		if err != nil {
157			yield(ai.StreamPart{
158				Type:  ai.StreamPartTypeError,
159				Error: ai.NewAIError("Unexpected", "error unmarshalling delta", err),
160			})
161			return ctx, false
162		}
163
164		emitEvent := func(reasoningContent string, signature string) bool {
165			if !reasoningStarted {
166				shouldContinue := yield(ai.StreamPart{
167					Type: ai.StreamPartTypeReasoningStart,
168					ID:   fmt.Sprintf("%d", inx),
169				})
170				if !shouldContinue {
171					return false
172				}
173			}
174
175			var metadata ai.ProviderMetadata
176
177			if signature != "" {
178				metadata = ai.ProviderMetadata{
179					Name: &ReasoningMetadata{
180						Signature: signature,
181					},
182					anthropic.Name: &anthropic.ReasoningOptionMetadata{
183						Signature: signature,
184					},
185				}
186			}
187
188			return yield(ai.StreamPart{
189				Type:             ai.StreamPartTypeReasoningDelta,
190				ID:               fmt.Sprintf("%d", inx),
191				Delta:            reasoningContent,
192				ProviderMetadata: metadata,
193			})
194		}
195		if len(reasoningData.ReasoningDetails) > 0 {
196			for _, detail := range reasoningData.ReasoningDetails {
197				if !reasoningStarted {
198					ctx[reasoningStartedCtx] = true
199				}
200				switch detail.Type {
201				case "reasoning.text":
202					return ctx, emitEvent(detail.Text, detail.Signature)
203				case "reasoning.summary":
204					return ctx, emitEvent(detail.Summary, detail.Signature)
205				case "reasoning.encrypted":
206					return ctx, emitEvent("[REDACTED]", detail.Signature)
207				}
208			}
209		} else if reasoningData.Reasoning != "" {
210			return ctx, emitEvent(reasoningData.Reasoning, "")
211		}
212		if reasoningStarted && (choice.Delta.Content != "" || len(choice.Delta.ToolCalls) > 0) {
213			ctx[reasoningStartedCtx] = false
214			return ctx, yield(ai.StreamPart{
215				Type: ai.StreamPartTypeReasoningEnd,
216				ID:   fmt.Sprintf("%d", inx),
217			})
218		}
219	}
220	return ctx, true
221}
222
223func languageModelUsage(response openaisdk.ChatCompletion) (ai.Usage, ai.ProviderOptionsData) {
224	if len(response.Choices) == 0 {
225		return ai.Usage{}, nil
226	}
227	openrouterUsage := UsageAccounting{}
228	usage := response.Usage
229
230	_ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
231
232	completionTokenDetails := usage.CompletionTokensDetails
233	promptTokenDetails := usage.PromptTokensDetails
234
235	var provider string
236	if p, ok := response.JSON.ExtraFields["provider"]; ok {
237		provider = p.Raw()
238	}
239
240	// Build provider metadata
241	providerMetadata := &ProviderMetadata{
242		Provider: provider,
243		Usage:    openrouterUsage,
244	}
245
246	return ai.Usage{
247		InputTokens:     usage.PromptTokens,
248		OutputTokens:    usage.CompletionTokens,
249		TotalTokens:     usage.TotalTokens,
250		ReasoningTokens: completionTokenDetails.ReasoningTokens,
251		CacheReadTokens: promptTokenDetails.CachedTokens,
252	}, providerMetadata
253}
254
255func languageModelStreamUsage(chunk openaisdk.ChatCompletionChunk, _ map[string]any, metadata ai.ProviderMetadata) (ai.Usage, ai.ProviderMetadata) {
256	usage := chunk.Usage
257	if usage.TotalTokens == 0 {
258		return ai.Usage{}, nil
259	}
260
261	streamProviderMetadata := &ProviderMetadata{}
262	if metadata != nil {
263		if providerMetadata, ok := metadata[Name]; ok {
264			converted, ok := providerMetadata.(*ProviderMetadata)
265			if ok {
266				streamProviderMetadata = converted
267			}
268		}
269	}
270	openrouterUsage := UsageAccounting{}
271	_ = json.Unmarshal([]byte(usage.RawJSON()), &openrouterUsage)
272	streamProviderMetadata.Usage = openrouterUsage
273
274	if p, ok := chunk.JSON.ExtraFields["provider"]; ok {
275		streamProviderMetadata.Provider = p.Raw()
276	}
277
278	// we do this here because the acc does not add prompt details
279	completionTokenDetails := usage.CompletionTokensDetails
280	promptTokenDetails := usage.PromptTokensDetails
281	aiUsage := ai.Usage{
282		InputTokens:     usage.PromptTokens,
283		OutputTokens:    usage.CompletionTokens,
284		TotalTokens:     usage.TotalTokens,
285		ReasoningTokens: completionTokenDetails.ReasoningTokens,
286		CacheReadTokens: promptTokenDetails.CachedTokens,
287	}
288
289	return aiUsage, ai.ProviderMetadata{
290		Name: streamProviderMetadata,
291	}
292}