1// Package main provides a command-line tool to fetch models from OpenRouter
2// and generate a configuration file for the provider.
3package main
4
5import (
6 "context"
7 "encoding/json"
8 "fmt"
9 "io"
10 "log"
11 "math"
12 "net/http"
13 "os"
14 "slices"
15 "strconv"
16 "strings"
17 "time"
18
19 "charm.land/catwalk/pkg/catwalk"
20)
21
// Model represents the complete model configuration as returned by the
// OpenRouter /api/v1/models endpoint. Field names mirror the API's JSON
// payload.
type Model struct {
	ID              string       `json:"id"`              // model identifier, also used to build the endpoints URL
	CanonicalSlug   string       `json:"canonical_slug"`
	HuggingFaceID   string       `json:"hugging_face_id"`
	Name            string       `json:"name"` // human-readable display name
	Created         int64        `json:"created"`
	Description     string       `json:"description"`
	ContextLength   int64        `json:"context_length"` // context window size as reported by the API
	Architecture    Architecture `json:"architecture"`
	Pricing         Pricing      `json:"pricing"`
	TopProvider     TopProvider  `json:"top_provider"`
	SupportedParams []string     `json:"supported_parameters"` // e.g. "tools", "reasoning"
}
36
// Architecture defines the model's architecture details, in particular
// which input/output modalities it supports.
type Architecture struct {
	Modality         string   `json:"modality"`
	InputModalities  []string `json:"input_modalities"`  // e.g. "text", "image"
	OutputModalities []string `json:"output_modalities"` // e.g. "text"
	Tokenizer        string   `json:"tokenizer"`
	InstructType     *string  `json:"instruct_type"` // nil when the API reports null
}
45
// Pricing contains the pricing information for different operations.
// All values are decimal strings as returned by the API (apparently
// price per single token, given the x1,000,000 scaling applied by
// getPricing); they are parsed with strconv.ParseFloat elsewhere.
type Pricing struct {
	Prompt            string `json:"prompt"`     // input (prompt) token price
	Completion        string `json:"completion"` // output (completion) token price
	Request           string `json:"request"`
	Image             string `json:"image"`
	WebSearch         string `json:"web_search"`
	InternalReasoning string `json:"internal_reasoning"`
	InputCacheRead    string `json:"input_cache_read"`
	InputCacheWrite   string `json:"input_cache_write"`
}
57
// TopProvider describes the top provider's capabilities for a model;
// used as a fallback when per-endpoint data cannot be fetched.
type TopProvider struct {
	ContextLength       int64  `json:"context_length"`
	MaxCompletionTokens *int64 `json:"max_completion_tokens"` // nil when the API reports null
	IsModerated         bool   `json:"is_moderated"`
}
64
// Endpoint represents a single provider endpoint serving a model, with
// its own pricing, limits, and health data.
type Endpoint struct {
	Name                string   `json:"name"`
	ContextLength       int64    `json:"context_length"`
	Pricing             Pricing  `json:"pricing"`
	ProviderName        string   `json:"provider_name"`
	Tag                 string   `json:"tag"`
	Quantization        *string  `json:"quantization"` // nil when the API reports null
	MaxCompletionTokens *int64   `json:"max_completion_tokens"`
	MaxPromptTokens     *int64   `json:"max_prompt_tokens"`
	SupportedParams     []string `json:"supported_parameters"`
	Status              int      `json:"status"`          // negative values are treated as unhealthy by selectBestEndpoint
	UptimeLast30m       float64  `json:"uptime_last_30m"` // compared against 90.0 when filtering — presumably a percentage
}
79
// EndpointsResponse is the response structure for the endpoints API
// (/api/v1/models/{id}/endpoints): model metadata plus the list of
// provider endpoints serving that model.
type EndpointsResponse struct {
	Data struct {
		ID          string     `json:"id"`
		Name        string     `json:"name"`
		Created     int64      `json:"created"`
		Description string     `json:"description"`
		Endpoints   []Endpoint `json:"endpoints"`
	} `json:"data"`
}
90
// ModelsResponse is the response structure for the models API
// (/api/v1/models).
type ModelsResponse struct {
	Data []Model `json:"data"`
}
95
// ModelPricing is the pricing structure for a model, detailing costs per
// million tokens for input and output, both cached and uncached.
// Populated from the string-based Pricing fields; see getPricing.
type ModelPricing struct {
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`  // filled from the input_cache_write price
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` // filled from the input_cache_read price
}
104
// roundCost rounds a cost value to five decimal places.
func roundCost(v float64) float64 {
	const scale = 1e5
	return math.Round(v*scale) / scale
}
108
109func getPricing(model Model) ModelPricing {
110 pricing := ModelPricing{}
111 costPrompt, err := strconv.ParseFloat(model.Pricing.Prompt, 64)
112 if err != nil {
113 costPrompt = 0.0
114 }
115 pricing.CostPer1MIn = roundCost(costPrompt * 1_000_000)
116 costCompletion, err := strconv.ParseFloat(model.Pricing.Completion, 64)
117 if err != nil {
118 costCompletion = 0.0
119 }
120 pricing.CostPer1MOut = roundCost(costCompletion * 1_000_000)
121
122 costPromptCached, err := strconv.ParseFloat(model.Pricing.InputCacheWrite, 64)
123 if err != nil {
124 costPromptCached = 0.0
125 }
126 pricing.CostPer1MInCached = roundCost(costPromptCached * 1_000_000)
127 costCompletionCached, err := strconv.ParseFloat(model.Pricing.InputCacheRead, 64)
128 if err != nil {
129 costCompletionCached = 0.0
130 }
131 pricing.CostPer1MOutCached = roundCost(costCompletionCached * 1_000_000)
132 return pricing
133}
134
135func fetchOpenRouterModels() (*ModelsResponse, error) {
136 client := &http.Client{Timeout: 30 * time.Second}
137 req, _ := http.NewRequestWithContext(
138 context.Background(),
139 "GET",
140 "https://openrouter.ai/api/v1/models",
141 nil,
142 )
143 req.Header.Set("User-Agent", "Crush-Client/1.0")
144
145 resp, err := client.Do(req)
146 if err != nil {
147 return nil, err //nolint:wrapcheck
148 }
149 defer resp.Body.Close() //nolint:errcheck
150
151 body, err := io.ReadAll(resp.Body)
152 if err != nil {
153 return nil, fmt.Errorf("unable to read models response body: %w", err)
154 }
155
156 if resp.StatusCode != http.StatusOK {
157 return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
158 }
159
160 // for debugging
161 _ = os.MkdirAll("tmp", 0o700)
162 _ = os.WriteFile("tmp/openrouter-response.json", body, 0o600)
163
164 var mr ModelsResponse
165 if err := json.Unmarshal(body, &mr); err != nil {
166 return nil, err //nolint:wrapcheck
167 }
168 return &mr, nil
169}
170
171func fetchModelEndpoints(modelID string) (*EndpointsResponse, error) {
172 client := &http.Client{Timeout: 30 * time.Second}
173 url := fmt.Sprintf("https://openrouter.ai/api/v1/models/%s/endpoints", modelID)
174 req, _ := http.NewRequestWithContext(
175 context.Background(),
176 "GET",
177 url,
178 nil,
179 )
180 req.Header.Set("User-Agent", "Crush-Client/1.0")
181 resp, err := client.Do(req)
182 if err != nil {
183 return nil, err //nolint:wrapcheck
184 }
185 defer resp.Body.Close() //nolint:errcheck
186 if resp.StatusCode != 200 {
187 body, _ := io.ReadAll(resp.Body)
188 return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
189 }
190 var er EndpointsResponse
191 if err := json.NewDecoder(resp.Body).Decode(&er); err != nil {
192 return nil, err //nolint:wrapcheck
193 }
194 return &er, nil
195}
196
197func selectBestEndpoint(endpoints []Endpoint) *Endpoint {
198 if len(endpoints) == 0 {
199 return nil
200 }
201
202 var best *Endpoint
203 for i := range endpoints {
204 endpoint := &endpoints[i]
205 // Skip endpoints with poor status or uptime
206 if endpoint.Status < 0 || endpoint.UptimeLast30m < 90.0 {
207 continue
208 }
209
210 if best == nil {
211 best = endpoint
212 continue
213 }
214
215 if isBetterEndpoint(endpoint, best) {
216 best = endpoint
217 }
218 }
219
220 // If no good endpoint found, return the first one as fallback
221 if best == nil {
222 best = &endpoints[0]
223 }
224
225 return best
226}
227
228func isBetterEndpoint(candidate, current *Endpoint) bool {
229 candidateHasTools := slices.Contains(candidate.SupportedParams, "tools")
230 currentHasTools := slices.Contains(current.SupportedParams, "tools")
231
232 // Prefer endpoints with tool support over those without
233 if candidateHasTools && !currentHasTools {
234 return true
235 }
236 if !candidateHasTools && currentHasTools {
237 return false
238 }
239
240 // Both have same tool support status, compare other factors
241 if candidate.ContextLength > current.ContextLength {
242 return true
243 }
244 if candidate.ContextLength == current.ContextLength {
245 return candidate.UptimeLast30m > current.UptimeLast30m
246 }
247
248 return false
249}
250
251// This is used to generate the openrouter.json config file.
252func main() {
253 modelsResp, err := fetchOpenRouterModels()
254 if err != nil {
255 log.Fatal("Error fetching OpenRouter models:", err)
256 }
257
258 openRouterProvider := catwalk.Provider{
259 Name: "OpenRouter",
260 ID: "openrouter",
261 APIKey: "$OPENROUTER_API_KEY",
262 APIEndpoint: "https://openrouter.ai/api/v1",
263 Type: catwalk.TypeOpenRouter,
264 DefaultLargeModelID: "anthropic/claude-sonnet-4",
265 DefaultSmallModelID: "anthropic/claude-3.5-haiku",
266 Models: []catwalk.Model{},
267 DefaultHeaders: map[string]string{
268 "HTTP-Referer": "https://charm.land",
269 "X-Title": "Crush",
270 },
271 }
272
273 for _, model := range modelsResp.Data {
274 if model.ContextLength < 20000 {
275 continue
276 }
277 // skip nonβtext models or those without tools
278 if !slices.Contains(model.SupportedParams, "tools") ||
279 !slices.Contains(model.Architecture.InputModalities, "text") ||
280 !slices.Contains(model.Architecture.OutputModalities, "text") {
281 continue
282 }
283
284 // Fetch endpoints for this model to get the best configuration
285 endpointsResp, err := fetchModelEndpoints(model.ID)
286 if err != nil {
287 fmt.Printf("Warning: Failed to fetch endpoints for %s: %v\n", model.ID, err)
288 // Fall back to using the original model data
289 pricing := getPricing(model)
290 canReason := slices.Contains(model.SupportedParams, "reasoning")
291 supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
292
293 var reasoningLevels []string
294 var defaultReasoning string
295 if canReason {
296 reasoningLevels = []string{"low", "medium", "high"}
297 defaultReasoning = "medium"
298 }
299 m := catwalk.Model{
300 ID: model.ID,
301 Name: model.Name,
302 CostPer1MIn: pricing.CostPer1MIn,
303 CostPer1MOut: pricing.CostPer1MOut,
304 CostPer1MInCached: pricing.CostPer1MInCached,
305 CostPer1MOutCached: pricing.CostPer1MOutCached,
306 ContextWindow: model.ContextLength,
307 CanReason: canReason,
308 DefaultReasoningEffort: defaultReasoning,
309 ReasoningLevels: reasoningLevels,
310 SupportsImages: supportsImages,
311 }
312 if model.TopProvider.MaxCompletionTokens != nil {
313 m.DefaultMaxTokens = *model.TopProvider.MaxCompletionTokens / 2
314 } else {
315 m.DefaultMaxTokens = model.ContextLength / 10
316 }
317 if model.TopProvider.ContextLength > 0 {
318 m.ContextWindow = model.TopProvider.ContextLength
319 }
320 openRouterProvider.Models = append(openRouterProvider.Models, m)
321 continue
322 }
323
324 // Select the best endpoint
325 bestEndpoint := selectBestEndpoint(endpointsResp.Data.Endpoints)
326 if bestEndpoint == nil {
327 fmt.Printf("Warning: No suitable endpoint found for %s\n", model.ID)
328 continue
329 }
330
331 // Check if the best endpoint supports tools
332 if !slices.Contains(bestEndpoint.SupportedParams, "tools") {
333 continue
334 }
335
336 // Use the best endpoint's configuration
337 pricing := ModelPricing{}
338 costPrompt, err := strconv.ParseFloat(bestEndpoint.Pricing.Prompt, 64)
339 if err != nil {
340 costPrompt = 0.0
341 }
342 pricing.CostPer1MIn = roundCost(costPrompt * 1_000_000)
343 costCompletion, err := strconv.ParseFloat(bestEndpoint.Pricing.Completion, 64)
344 if err != nil {
345 costCompletion = 0.0
346 }
347 pricing.CostPer1MOut = roundCost(costCompletion * 1_000_000)
348
349 costPromptCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheWrite, 64)
350 if err != nil {
351 costPromptCached = 0.0
352 }
353 pricing.CostPer1MInCached = roundCost(costPromptCached * 1_000_000)
354 costCompletionCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheRead, 64)
355 if err != nil {
356 costCompletionCached = 0.0
357 }
358 pricing.CostPer1MOutCached = roundCost(costCompletionCached * 1_000_000)
359
360 canReason := slices.Contains(bestEndpoint.SupportedParams, "reasoning")
361 supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
362
363 var reasoningLevels []string
364 var defaultReasoning string
365 if canReason {
366 reasoningLevels = []string{"low", "medium", "high"}
367 defaultReasoning = "medium"
368 }
369 m := catwalk.Model{
370 ID: model.ID,
371 Name: model.Name,
372 CostPer1MIn: pricing.CostPer1MIn,
373 CostPer1MOut: pricing.CostPer1MOut,
374 CostPer1MInCached: pricing.CostPer1MInCached,
375 CostPer1MOutCached: pricing.CostPer1MOutCached,
376 ContextWindow: bestEndpoint.ContextLength,
377 CanReason: canReason,
378 DefaultReasoningEffort: defaultReasoning,
379 ReasoningLevels: reasoningLevels,
380 SupportsImages: supportsImages,
381 }
382
383 // Set max tokens based on the best endpoint
384 if bestEndpoint.MaxCompletionTokens != nil {
385 m.DefaultMaxTokens = *bestEndpoint.MaxCompletionTokens / 2
386 } else {
387 m.DefaultMaxTokens = bestEndpoint.ContextLength / 10
388 }
389
390 openRouterProvider.Models = append(openRouterProvider.Models, m)
391 fmt.Printf("Added model %s with context window %d from provider %s\n",
392 model.ID, bestEndpoint.ContextLength, bestEndpoint.ProviderName)
393 }
394
395 slices.SortFunc(openRouterProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
396 return strings.Compare(a.Name, b.Name)
397 })
398
399 // save the json in internal/providers/config/openrouter.json
400 data, err := json.MarshalIndent(openRouterProvider, "", " ")
401 if err != nil {
402 log.Fatal("Error marshaling OpenRouter provider:", err)
403 }
404 // write to file
405 if err := os.WriteFile("internal/providers/configs/openrouter.json", data, 0o600); err != nil {
406 log.Fatal("Error writing OpenRouter provider config:", err)
407 }
408}