language_model_hooks.go

 1package openrouter
 2
 3import (
 4	"maps"
 5
 6	"github.com/charmbracelet/fantasy/ai"
 7	openaisdk "github.com/openai/openai-go/v2"
 8	"github.com/openai/openai-go/v2/packages/param"
 9)
10
11func prepareLanguageModelCall(model ai.LanguageModel, params *openaisdk.ChatCompletionNewParams, call ai.Call) ([]ai.CallWarning, error) {
12	providerOptions := &ProviderOptions{}
13	if v, ok := call.ProviderOptions[Name]; ok {
14		providerOptions, ok = v.(*ProviderOptions)
15		if !ok {
16			return nil, ai.NewInvalidArgumentError("providerOptions", "openrouter provider options should be *openrouter.ProviderOptions", nil)
17		}
18	}
19
20	extraFields := make(map[string]any)
21
22	if providerOptions.Provider != nil {
23		data, err := structToMapJSON(providerOptions.Provider)
24		if err != nil {
25			return nil, err
26		}
27		extraFields["provider"] = data
28	}
29
30	if providerOptions.Reasoning != nil {
31		data, err := structToMapJSON(providerOptions.Reasoning)
32		if err != nil {
33			return nil, err
34		}
35		extraFields["reasoning"] = data
36	}
37
38	if providerOptions.IncludeUsage != nil {
39		extraFields["usage"] = map[string]any{
40			"include": *providerOptions.IncludeUsage,
41		}
42	}
43	if providerOptions.LogitBias != nil {
44		params.LogitBias = providerOptions.LogitBias
45	}
46	if providerOptions.LogProbs != nil {
47		params.Logprobs = param.NewOpt(*providerOptions.LogProbs)
48	}
49	if providerOptions.User != nil {
50		params.User = param.NewOpt(*providerOptions.User)
51	}
52	if providerOptions.ParallelToolCalls != nil {
53		params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls)
54	}
55
56	maps.Copy(extraFields, providerOptions.ExtraBody)
57	params.SetExtraFields(extraFields)
58	return nil, nil
59}