1// Package main provides a command-line tool to fetch models from OpenRouter
2// and generate a configuration file for the provider.
3package main
4
5import (
6 "context"
7 "encoding/json"
8 "fmt"
9 "io"
10 "log"
11 "net/http"
12 "os"
13 "slices"
14 "strconv"
15 "time"
16
17 "github.com/charmbracelet/catwalk/pkg/catwalk"
18)
19
// Model represents the complete model configuration: one entry of the "data"
// array returned by OpenRouter's /models endpoint (decoded via ModelsResponse).
//
// Fields read by this tool:
//   - ID and Name are copied into the generated catwalk model; ID is also
//     interpolated into the per-model endpoints URL.
//   - SupportedParams must contain "tools" for the model to be kept;
//     "reasoning" is checked on the fallback path.
//   - Architecture input/output modalities must both include "text"; an
//     "image" input modality marks the model as supporting images.
//   - Pricing, ContextLength and TopProvider are only used when the model's
//     endpoints cannot be fetched (the fallback path in main).
//
// CanonicalSlug, HuggingFaceID, Created and Description are decoded but not
// read anywhere in this file.
type Model struct {
	ID              string       `json:"id"`
	CanonicalSlug   string       `json:"canonical_slug"`
	HuggingFaceID   string       `json:"hugging_face_id"`
	Name            string       `json:"name"`
	Created         int64        `json:"created"`
	Description     string       `json:"description"`
	ContextLength   int64        `json:"context_length"`
	Architecture    Architecture `json:"architecture"`
	Pricing         Pricing      `json:"pricing"`
	TopProvider     TopProvider  `json:"top_provider"`
	SupportedParams []string     `json:"supported_parameters"`
}
34
// Architecture defines the model's architecture details as reported by the
// /models listing.
//
// Only InputModalities and OutputModalities are read in this file: both must
// contain "text" for a model to be kept, and an "image" input modality sets
// SupportsImages. InstructType is a pointer, so a JSON null decodes to nil.
type Architecture struct {
	Modality         string   `json:"modality"`
	InputModalities  []string `json:"input_modalities"`
	OutputModalities []string `json:"output_modalities"`
	Tokenizer        string   `json:"tokenizer"`
	InstructType     *string  `json:"instruct_type"`
}
43
// Pricing contains the pricing information for different operations, shared by
// Model and Endpoint. Prices arrive as decimal strings; getPricing parses them
// with strconv.ParseFloat and multiplies by 1,000,000, so they are presumably
// per-token prices (TODO confirm against the OpenRouter API docs).
//
// Only Prompt, Completion, InputCacheRead and InputCacheWrite are read by
// getPricing; a value that fails to parse (including "") is treated as 0.
type Pricing struct {
	Prompt            string `json:"prompt"`
	Completion        string `json:"completion"`
	Request           string `json:"request"`
	Image             string `json:"image"`
	WebSearch         string `json:"web_search"`
	InternalReasoning string `json:"internal_reasoning"`
	InputCacheRead    string `json:"input_cache_read"`
	InputCacheWrite   string `json:"input_cache_write"`
}
55
// TopProvider describes the top provider's capabilities as listed in the
// /models response. It is only consulted on main's fallback path:
// MaxCompletionTokens (nil when absent or null) is halved to obtain
// DefaultMaxTokens, and a positive ContextLength overrides the model's
// context window. IsModerated is not read in this file.
type TopProvider struct {
	ContextLength       int64  `json:"context_length"`
	MaxCompletionTokens *int64 `json:"max_completion_tokens"`
	IsModerated         bool   `json:"is_moderated"`
}
62
// Endpoint represents a single endpoint configuration for a model, i.e. one
// entry of the list returned by /models/{id}/endpoints.
//
// selectBestEndpoint ranks endpoints by Status (negative values are skipped),
// UptimeLast30m (values below 90.0 are skipped) and ContextLength. main then
// reads SupportedParams ("tools", "reasoning"), Pricing, ContextLength,
// MaxCompletionTokens and ProviderName from the chosen endpoint.
//
// NOTE(review): UptimeLast30m is a plain float64, so a JSON null decodes as 0
// and such an endpoint is treated as unhealthy; confirm whether the API ever
// sends null here.
type Endpoint struct {
	Name                string   `json:"name"`
	ContextLength       int64    `json:"context_length"`
	Pricing             Pricing  `json:"pricing"`
	ProviderName        string   `json:"provider_name"`
	Tag                 string   `json:"tag"`
	Quantization        *string  `json:"quantization"`
	MaxCompletionTokens *int64   `json:"max_completion_tokens"`
	MaxPromptTokens     *int64   `json:"max_prompt_tokens"`
	SupportedParams     []string `json:"supported_parameters"`
	Status              int      `json:"status"`
	UptimeLast30m       float64  `json:"uptime_last_30m"`
}
77
// EndpointsResponse is the response structure for the endpoints API
// (/models/{id}/endpoints). Only Data.Endpoints is read in this file; the
// other Data fields repeat model metadata already present in the /models
// listing.
type EndpointsResponse struct {
	Data struct {
		ID          string     `json:"id"`
		Name        string     `json:"name"`
		Created     int64      `json:"created"`
		Description string     `json:"description"`
		Endpoints   []Endpoint `json:"endpoints"`
	} `json:"data"`
}
88
// ModelsResponse is the response structure for the models API (/models);
// main iterates over Data to build the provider's model list.
type ModelsResponse struct {
	Data []Model `json:"data"`
}
93
// ModelPricing is the pricing structure for a model, detailing costs per
// million tokens for input and output, both cached and uncached.
//
// getPricing fills CostPer1MInCached from the cache-write price and
// CostPer1MOutCached from the cache-read price. In the visible code the
// values are copied field by field into catwalk.Model, and this struct itself
// is never marshaled despite its json tags.
type ModelPricing struct {
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
}
102
103func getPricing(model Model) ModelPricing {
104 pricing := ModelPricing{}
105 costPrompt, err := strconv.ParseFloat(model.Pricing.Prompt, 64)
106 if err != nil {
107 costPrompt = 0.0
108 }
109 pricing.CostPer1MIn = costPrompt * 1_000_000
110 costCompletion, err := strconv.ParseFloat(model.Pricing.Completion, 64)
111 if err != nil {
112 costCompletion = 0.0
113 }
114 pricing.CostPer1MOut = costCompletion * 1_000_000
115
116 costPromptCached, err := strconv.ParseFloat(model.Pricing.InputCacheWrite, 64)
117 if err != nil {
118 costPromptCached = 0.0
119 }
120 pricing.CostPer1MInCached = costPromptCached * 1_000_000
121 costCompletionCached, err := strconv.ParseFloat(model.Pricing.InputCacheRead, 64)
122 if err != nil {
123 costCompletionCached = 0.0
124 }
125 pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
126 return pricing
127}
128
129func fetchOpenRouterModels() (*ModelsResponse, error) {
130 client := &http.Client{Timeout: 30 * time.Second}
131 req, _ := http.NewRequestWithContext(
132 context.Background(),
133 "GET",
134 "https://openrouter.ai/api/v1/models",
135 nil,
136 )
137 req.Header.Set("User-Agent", "Crush-Client/1.0")
138 resp, err := client.Do(req)
139 if err != nil {
140 return nil, err //nolint:wrapcheck
141 }
142 defer resp.Body.Close() //nolint:errcheck
143 if resp.StatusCode != 200 {
144 body, _ := io.ReadAll(resp.Body)
145 return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
146 }
147 var mr ModelsResponse
148 if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
149 return nil, err //nolint:wrapcheck
150 }
151 return &mr, nil
152}
153
154func fetchModelEndpoints(modelID string) (*EndpointsResponse, error) {
155 client := &http.Client{Timeout: 30 * time.Second}
156 url := fmt.Sprintf("https://openrouter.ai/api/v1/models/%s/endpoints", modelID)
157 req, _ := http.NewRequestWithContext(
158 context.Background(),
159 "GET",
160 url,
161 nil,
162 )
163 req.Header.Set("User-Agent", "Crush-Client/1.0")
164 resp, err := client.Do(req)
165 if err != nil {
166 return nil, err //nolint:wrapcheck
167 }
168 defer resp.Body.Close() //nolint:errcheck
169 if resp.StatusCode != 200 {
170 body, _ := io.ReadAll(resp.Body)
171 return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
172 }
173 var er EndpointsResponse
174 if err := json.NewDecoder(resp.Body).Decode(&er); err != nil {
175 return nil, err //nolint:wrapcheck
176 }
177 return &er, nil
178}
179
180func selectBestEndpoint(endpoints []Endpoint) *Endpoint {
181 if len(endpoints) == 0 {
182 return nil
183 }
184
185 var best *Endpoint
186 for i := range endpoints {
187 endpoint := &endpoints[i]
188 // Skip endpoints with poor status or uptime
189 if endpoint.Status < 0 || endpoint.UptimeLast30m < 90.0 {
190 continue
191 }
192
193 if best == nil {
194 best = endpoint
195 continue
196 }
197
198 // Prefer higher context length
199 if endpoint.ContextLength > best.ContextLength {
200 best = endpoint
201 } else if endpoint.ContextLength == best.ContextLength {
202 // If context length is the same, prefer better uptime
203 if endpoint.UptimeLast30m > best.UptimeLast30m {
204 best = endpoint
205 }
206 }
207 }
208
209 // If no good endpoint found, return the first one as fallback
210 if best == nil {
211 best = &endpoints[0]
212 }
213
214 return best
215}
216
217// This is used to generate the openrouter.json config file.
218func main() {
219 modelsResp, err := fetchOpenRouterModels()
220 if err != nil {
221 log.Fatal("Error fetching OpenRouter models:", err)
222 }
223
224 openRouterProvider := catwalk.Provider{
225 Name: "OpenRouter",
226 ID: "openrouter",
227 APIKey: "$OPENROUTER_API_KEY",
228 APIEndpoint: "https://openrouter.ai/api/v1",
229 Type: catwalk.TypeOpenAI,
230 DefaultLargeModelID: "anthropic/claude-sonnet-4",
231 DefaultSmallModelID: "anthropic/claude-3.5-haiku",
232 Models: []catwalk.Model{},
233 }
234
235 for _, model := range modelsResp.Data {
236 // skip nonβtext models or those without tools
237 if !slices.Contains(model.SupportedParams, "tools") ||
238 !slices.Contains(model.Architecture.InputModalities, "text") ||
239 !slices.Contains(model.Architecture.OutputModalities, "text") {
240 continue
241 }
242
243 // Fetch endpoints for this model to get the best configuration
244 endpointsResp, err := fetchModelEndpoints(model.ID)
245 if err != nil {
246 fmt.Printf("Warning: Failed to fetch endpoints for %s: %v\n", model.ID, err)
247 // Fall back to using the original model data
248 pricing := getPricing(model)
249 canReason := slices.Contains(model.SupportedParams, "reasoning")
250 supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
251
252 m := catwalk.Model{
253 ID: model.ID,
254 Name: model.Name,
255 CostPer1MIn: pricing.CostPer1MIn,
256 CostPer1MOut: pricing.CostPer1MOut,
257 CostPer1MInCached: pricing.CostPer1MInCached,
258 CostPer1MOutCached: pricing.CostPer1MOutCached,
259 ContextWindow: model.ContextLength,
260 CanReason: canReason,
261 SupportsImages: supportsImages,
262 }
263 if model.TopProvider.MaxCompletionTokens != nil {
264 m.DefaultMaxTokens = *model.TopProvider.MaxCompletionTokens / 2
265 } else {
266 m.DefaultMaxTokens = model.ContextLength / 10
267 }
268 if model.TopProvider.ContextLength > 0 {
269 m.ContextWindow = model.TopProvider.ContextLength
270 }
271 openRouterProvider.Models = append(openRouterProvider.Models, m)
272 continue
273 }
274
275 // Select the best endpoint
276 bestEndpoint := selectBestEndpoint(endpointsResp.Data.Endpoints)
277 if bestEndpoint == nil {
278 fmt.Printf("Warning: No suitable endpoint found for %s\n", model.ID)
279 continue
280 }
281
282 // Check if the best endpoint supports tools
283 if !slices.Contains(bestEndpoint.SupportedParams, "tools") {
284 continue
285 }
286
287 // Use the best endpoint's configuration
288 pricing := ModelPricing{}
289 costPrompt, err := strconv.ParseFloat(bestEndpoint.Pricing.Prompt, 64)
290 if err != nil {
291 costPrompt = 0.0
292 }
293 pricing.CostPer1MIn = costPrompt * 1_000_000
294 costCompletion, err := strconv.ParseFloat(bestEndpoint.Pricing.Completion, 64)
295 if err != nil {
296 costCompletion = 0.0
297 }
298 pricing.CostPer1MOut = costCompletion * 1_000_000
299
300 costPromptCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheWrite, 64)
301 if err != nil {
302 costPromptCached = 0.0
303 }
304 pricing.CostPer1MInCached = costPromptCached * 1_000_000
305 costCompletionCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheRead, 64)
306 if err != nil {
307 costCompletionCached = 0.0
308 }
309 pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
310
311 canReason := slices.Contains(bestEndpoint.SupportedParams, "reasoning")
312 supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
313
314 m := catwalk.Model{
315 ID: model.ID,
316 Name: model.Name,
317 CostPer1MIn: pricing.CostPer1MIn,
318 CostPer1MOut: pricing.CostPer1MOut,
319 CostPer1MInCached: pricing.CostPer1MInCached,
320 CostPer1MOutCached: pricing.CostPer1MOutCached,
321 ContextWindow: bestEndpoint.ContextLength,
322 CanReason: canReason,
323 SupportsImages: supportsImages,
324 }
325
326 // Set max tokens based on the best endpoint
327 if bestEndpoint.MaxCompletionTokens != nil {
328 m.DefaultMaxTokens = *bestEndpoint.MaxCompletionTokens / 2
329 } else {
330 m.DefaultMaxTokens = bestEndpoint.ContextLength / 10
331 }
332
333 openRouterProvider.Models = append(openRouterProvider.Models, m)
334 fmt.Printf("Added model %s with context window %d from provider %s\n",
335 model.ID, bestEndpoint.ContextLength, bestEndpoint.ProviderName)
336 }
337
338 // save the json in internal/providers/config/openrouter.json
339 data, err := json.MarshalIndent(openRouterProvider, "", " ")
340 if err != nil {
341 log.Fatal("Error marshaling OpenRouter provider:", err)
342 }
343 // write to file
344 if err := os.WriteFile("internal/providers/configs/openrouter.json", data, 0o600); err != nil {
345 log.Fatal("Error writing OpenRouter provider config:", err)
346 }
347}