1// Package main provides a command-line tool to fetch models from OpenRouter
2// and generate a configuration file for the provider.
3package main
4
5import (
6 "context"
7 "encoding/json"
8 "fmt"
9 "io"
10 "log"
11 "net/http"
12 "os"
13 "slices"
14 "strconv"
15 "time"
16
17 "github.com/charmbracelet/catwalk/pkg/catwalk"
18)
19
// Model represents the complete model configuration as returned by the
// OpenRouter /models API.
type Model struct {
	ID              string       `json:"id"`                   // provider-qualified model ID, e.g. "anthropic/claude-sonnet-4"
	CanonicalSlug   string       `json:"canonical_slug"`       // stable slug for the model
	HuggingFaceID   string       `json:"hugging_face_id"`      // upstream Hugging Face identifier, if any
	Name            string       `json:"name"`                 // human-readable display name
	Created         int64        `json:"created"`              // creation time (unix seconds, per API convention — TODO confirm)
	Description     string       `json:"description"`
	ContextLength   int64        `json:"context_length"`       // maximum context window in tokens
	Architecture    Architecture `json:"architecture"`         // modality/tokenizer details
	Pricing         Pricing      `json:"pricing"`              // per-token prices as decimal strings
	TopProvider     TopProvider  `json:"top_provider"`         // capabilities of the primary hosting provider
	SupportedParams []string     `json:"supported_parameters"` // supported request parameters, e.g. "tools", "reasoning"
}
34
// Architecture defines the model's architecture details: which input and
// output modalities it handles and how it tokenizes text.
type Architecture struct {
	Modality         string   `json:"modality"`          // combined modality descriptor, e.g. "text+image->text" — TODO confirm format
	InputModalities  []string `json:"input_modalities"`  // accepted inputs, e.g. "text", "image"
	OutputModalities []string `json:"output_modalities"` // produced outputs, e.g. "text"
	Tokenizer        string   `json:"tokenizer"`         // tokenizer family name
	InstructType     *string  `json:"instruct_type"`     // instruction format; nil when not applicable
}
43
// Pricing contains the pricing information for different operations.
// All values are decimal strings representing the cost per single token
// (or per request/image/search where applicable); getPricing scales them
// to cost per million tokens.
type Pricing struct {
	Prompt            string `json:"prompt"`             // cost per input (prompt) token
	Completion        string `json:"completion"`         // cost per output (completion) token
	Request           string `json:"request"`            // flat cost per request, if any
	Image             string `json:"image"`              // cost per image input
	WebSearch         string `json:"web_search"`         // cost per web-search invocation
	InternalReasoning string `json:"internal_reasoning"` // cost per internal reasoning token
	InputCacheRead    string `json:"input_cache_read"`   // cost per cached-input token read
	InputCacheWrite   string `json:"input_cache_write"`  // cost per cached-input token write
}
55
// TopProvider describes the top provider's capabilities for a model, used
// as a fallback when per-endpoint data cannot be fetched.
type TopProvider struct {
	ContextLength       int64  `json:"context_length"`        // provider-specific context window; 0 when unknown
	MaxCompletionTokens *int64 `json:"max_completion_tokens"` // completion token cap; nil when unlimited/unknown
	IsModerated         bool   `json:"is_moderated"`          // whether the provider applies content moderation
}
62
// Endpoint represents a single endpoint configuration for a model, i.e.
// one concrete provider deployment serving that model.
type Endpoint struct {
	Name                string   `json:"name"`
	ContextLength       int64    `json:"context_length"`        // context window offered by this endpoint
	Pricing             Pricing  `json:"pricing"`               // endpoint-specific per-token prices
	ProviderName        string   `json:"provider_name"`         // hosting provider, e.g. "Anthropic"
	Tag                 string   `json:"tag"`
	Quantization        *string  `json:"quantization"`          // quantization level; nil for full precision — TODO confirm
	MaxCompletionTokens *int64   `json:"max_completion_tokens"` // completion cap; nil when unlimited/unknown
	MaxPromptTokens     *int64   `json:"max_prompt_tokens"`     // prompt cap; nil when unlimited/unknown
	SupportedParams     []string `json:"supported_parameters"`  // parameters this endpoint accepts, e.g. "tools"
	Status              int      `json:"status"`                // health indicator; negative values are treated as unhealthy
	UptimeLast30m       float64  `json:"uptime_last_30m"`       // uptime percentage over the last 30 minutes (0-100)
}
77
// EndpointsResponse is the response structure for the per-model
// /models/{id}/endpoints API.
type EndpointsResponse struct {
	Data struct {
		ID          string     `json:"id"`          // the model's ID
		Name        string     `json:"name"`        // the model's display name
		Created     int64      `json:"created"`
		Description string     `json:"description"`
		Endpoints   []Endpoint `json:"endpoints"`   // all endpoints currently serving this model
	} `json:"data"`
}
88
// ModelsResponse is the response structure for the /models listing API.
type ModelsResponse struct {
	Data []Model `json:"data"` // the full model catalog
}
93
// ModelPricing is the pricing structure for a model, detailing costs per
// million tokens for input and output, both cached and uncached.
type ModelPricing struct {
	CostPer1MIn        float64 `json:"cost_per_1m_in"`         // cost per 1M uncached input tokens
	CostPer1MOut       float64 `json:"cost_per_1m_out"`        // cost per 1M uncached output tokens
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`  // cost per 1M cache-write input tokens
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` // cost per 1M cache-read tokens
}
102
103func getPricing(model Model) ModelPricing {
104 pricing := ModelPricing{}
105 costPrompt, err := strconv.ParseFloat(model.Pricing.Prompt, 64)
106 if err != nil {
107 costPrompt = 0.0
108 }
109 pricing.CostPer1MIn = costPrompt * 1_000_000
110 costCompletion, err := strconv.ParseFloat(model.Pricing.Completion, 64)
111 if err != nil {
112 costCompletion = 0.0
113 }
114 pricing.CostPer1MOut = costCompletion * 1_000_000
115
116 costPromptCached, err := strconv.ParseFloat(model.Pricing.InputCacheWrite, 64)
117 if err != nil {
118 costPromptCached = 0.0
119 }
120 pricing.CostPer1MInCached = costPromptCached * 1_000_000
121 costCompletionCached, err := strconv.ParseFloat(model.Pricing.InputCacheRead, 64)
122 if err != nil {
123 costCompletionCached = 0.0
124 }
125 pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
126 return pricing
127}
128
129func fetchOpenRouterModels() (*ModelsResponse, error) {
130 client := &http.Client{Timeout: 30 * time.Second}
131 req, _ := http.NewRequestWithContext(
132 context.Background(),
133 "GET",
134 "https://openrouter.ai/api/v1/models",
135 nil,
136 )
137 req.Header.Set("User-Agent", "Crush-Client/1.0")
138 resp, err := client.Do(req)
139 if err != nil {
140 return nil, err //nolint:wrapcheck
141 }
142 defer resp.Body.Close() //nolint:errcheck
143 if resp.StatusCode != 200 {
144 body, _ := io.ReadAll(resp.Body)
145 return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
146 }
147 var mr ModelsResponse
148 if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
149 return nil, err //nolint:wrapcheck
150 }
151 return &mr, nil
152}
153
154func fetchModelEndpoints(modelID string) (*EndpointsResponse, error) {
155 client := &http.Client{Timeout: 30 * time.Second}
156 url := fmt.Sprintf("https://openrouter.ai/api/v1/models/%s/endpoints", modelID)
157 req, _ := http.NewRequestWithContext(
158 context.Background(),
159 "GET",
160 url,
161 nil,
162 )
163 req.Header.Set("User-Agent", "Crush-Client/1.0")
164 resp, err := client.Do(req)
165 if err != nil {
166 return nil, err //nolint:wrapcheck
167 }
168 defer resp.Body.Close() //nolint:errcheck
169 if resp.StatusCode != 200 {
170 body, _ := io.ReadAll(resp.Body)
171 return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
172 }
173 var er EndpointsResponse
174 if err := json.NewDecoder(resp.Body).Decode(&er); err != nil {
175 return nil, err //nolint:wrapcheck
176 }
177 return &er, nil
178}
179
180func selectBestEndpoint(endpoints []Endpoint) *Endpoint {
181 if len(endpoints) == 0 {
182 return nil
183 }
184
185 var best *Endpoint
186 for i := range endpoints {
187 endpoint := &endpoints[i]
188 // Skip endpoints with poor status or uptime
189 if endpoint.Status < 0 || endpoint.UptimeLast30m < 90.0 {
190 continue
191 }
192
193 if best == nil {
194 best = endpoint
195 continue
196 }
197
198 if isBetterEndpoint(endpoint, best) {
199 best = endpoint
200 }
201 }
202
203 // If no good endpoint found, return the first one as fallback
204 if best == nil {
205 best = &endpoints[0]
206 }
207
208 return best
209}
210
211func isBetterEndpoint(candidate, current *Endpoint) bool {
212 candidateHasTools := slices.Contains(candidate.SupportedParams, "tools")
213 currentHasTools := slices.Contains(current.SupportedParams, "tools")
214
215 // Prefer endpoints with tool support over those without
216 if candidateHasTools && !currentHasTools {
217 return true
218 }
219 if !candidateHasTools && currentHasTools {
220 return false
221 }
222
223 // Both have same tool support status, compare other factors
224 if candidate.ContextLength > current.ContextLength {
225 return true
226 }
227 if candidate.ContextLength == current.ContextLength {
228 return candidate.UptimeLast30m > current.UptimeLast30m
229 }
230
231 return false
232}
233
234// This is used to generate the openrouter.json config file.
235func main() {
236 modelsResp, err := fetchOpenRouterModels()
237 if err != nil {
238 log.Fatal("Error fetching OpenRouter models:", err)
239 }
240
241 openRouterProvider := catwalk.Provider{
242 Name: "OpenRouter",
243 ID: "openrouter",
244 APIKey: "$OPENROUTER_API_KEY",
245 APIEndpoint: "https://openrouter.ai/api/v1",
246 Type: catwalk.TypeOpenAI,
247 DefaultLargeModelID: "anthropic/claude-sonnet-4",
248 DefaultSmallModelID: "anthropic/claude-3.5-haiku",
249 Models: []catwalk.Model{},
250 }
251
252 for _, model := range modelsResp.Data {
253 // skip nonβtext models or those without tools
254 if !slices.Contains(model.SupportedParams, "tools") ||
255 !slices.Contains(model.Architecture.InputModalities, "text") ||
256 !slices.Contains(model.Architecture.OutputModalities, "text") {
257 continue
258 }
259
260 // Fetch endpoints for this model to get the best configuration
261 endpointsResp, err := fetchModelEndpoints(model.ID)
262 if err != nil {
263 fmt.Printf("Warning: Failed to fetch endpoints for %s: %v\n", model.ID, err)
264 // Fall back to using the original model data
265 pricing := getPricing(model)
266 canReason := slices.Contains(model.SupportedParams, "reasoning")
267 supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
268
269 m := catwalk.Model{
270 ID: model.ID,
271 Name: model.Name,
272 CostPer1MIn: pricing.CostPer1MIn,
273 CostPer1MOut: pricing.CostPer1MOut,
274 CostPer1MInCached: pricing.CostPer1MInCached,
275 CostPer1MOutCached: pricing.CostPer1MOutCached,
276 ContextWindow: model.ContextLength,
277 CanReason: canReason,
278 SupportsImages: supportsImages,
279 }
280 if model.TopProvider.MaxCompletionTokens != nil {
281 m.DefaultMaxTokens = *model.TopProvider.MaxCompletionTokens / 2
282 } else {
283 m.DefaultMaxTokens = model.ContextLength / 10
284 }
285 if model.TopProvider.ContextLength > 0 {
286 m.ContextWindow = model.TopProvider.ContextLength
287 }
288 openRouterProvider.Models = append(openRouterProvider.Models, m)
289 continue
290 }
291
292 // Select the best endpoint
293 bestEndpoint := selectBestEndpoint(endpointsResp.Data.Endpoints)
294 if bestEndpoint == nil {
295 fmt.Printf("Warning: No suitable endpoint found for %s\n", model.ID)
296 continue
297 }
298
299 // Check if the best endpoint supports tools
300 if !slices.Contains(bestEndpoint.SupportedParams, "tools") {
301 continue
302 }
303
304 // Use the best endpoint's configuration
305 pricing := ModelPricing{}
306 costPrompt, err := strconv.ParseFloat(bestEndpoint.Pricing.Prompt, 64)
307 if err != nil {
308 costPrompt = 0.0
309 }
310 pricing.CostPer1MIn = costPrompt * 1_000_000
311 costCompletion, err := strconv.ParseFloat(bestEndpoint.Pricing.Completion, 64)
312 if err != nil {
313 costCompletion = 0.0
314 }
315 pricing.CostPer1MOut = costCompletion * 1_000_000
316
317 costPromptCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheWrite, 64)
318 if err != nil {
319 costPromptCached = 0.0
320 }
321 pricing.CostPer1MInCached = costPromptCached * 1_000_000
322 costCompletionCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheRead, 64)
323 if err != nil {
324 costCompletionCached = 0.0
325 }
326 pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
327
328 canReason := slices.Contains(bestEndpoint.SupportedParams, "reasoning")
329 supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
330
331 m := catwalk.Model{
332 ID: model.ID,
333 Name: model.Name,
334 CostPer1MIn: pricing.CostPer1MIn,
335 CostPer1MOut: pricing.CostPer1MOut,
336 CostPer1MInCached: pricing.CostPer1MInCached,
337 CostPer1MOutCached: pricing.CostPer1MOutCached,
338 ContextWindow: bestEndpoint.ContextLength,
339 CanReason: canReason,
340 SupportsImages: supportsImages,
341 }
342
343 // Set max tokens based on the best endpoint
344 if bestEndpoint.MaxCompletionTokens != nil {
345 m.DefaultMaxTokens = *bestEndpoint.MaxCompletionTokens / 2
346 } else {
347 m.DefaultMaxTokens = bestEndpoint.ContextLength / 10
348 }
349
350 openRouterProvider.Models = append(openRouterProvider.Models, m)
351 fmt.Printf("Added model %s with context window %d from provider %s\n",
352 model.ID, bestEndpoint.ContextLength, bestEndpoint.ProviderName)
353 }
354
355 // save the json in internal/providers/config/openrouter.json
356 data, err := json.MarshalIndent(openRouterProvider, "", " ")
357 if err != nil {
358 log.Fatal("Error marshaling OpenRouter provider:", err)
359 }
360 // write to file
361 if err := os.WriteFile("internal/providers/configs/openrouter.json", data, 0o600); err != nil {
362 log.Fatal("Error writing OpenRouter provider config:", err)
363 }
364}