main.go

  1// Package main provides a command-line tool to fetch models from OpenRouter
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"net/http"
 12	"os"
 13	"slices"
 14	"strconv"
 15	"time"
 16
 17	"github.com/charmbracelet/catwalk/pkg/catwalk"
 18)
 19
// Model represents the complete model configuration as returned by the
// OpenRouter /api/v1/models catalog endpoint.
type Model struct {
	ID              string       `json:"id"`
	CanonicalSlug   string       `json:"canonical_slug"`
	HuggingFaceID   string       `json:"hugging_face_id"`
	Name            string       `json:"name"`
	Created         int64        `json:"created"` // presumably a Unix timestamp; not used by this tool
	Description     string       `json:"description"`
	ContextLength   int64        `json:"context_length"` // context window size in tokens
	Architecture    Architecture `json:"architecture"`
	Pricing         Pricing      `json:"pricing"`
	TopProvider     TopProvider  `json:"top_provider"`
	SupportedParams []string     `json:"supported_parameters"` // feature flags, e.g. "tools", "reasoning"
}
 34
// Architecture defines the model's architecture details, in particular which
// input/output modalities it accepts (used below to filter for text models).
type Architecture struct {
	Modality         string   `json:"modality"`
	InputModalities  []string `json:"input_modalities"`  // e.g. "text", "image"
	OutputModalities []string `json:"output_modalities"` // e.g. "text"
	Tokenizer        string   `json:"tokenizer"`
	InstructType     *string  `json:"instruct_type"` // optional; nil when absent from the response
}
 43
// Pricing contains the pricing information for different operations.
// Values are decimal strings; the per-token fields are parsed with
// strconv.ParseFloat and scaled to per-million-token costs in getPricing.
type Pricing struct {
	Prompt            string `json:"prompt"`     // cost per input (prompt) token
	Completion        string `json:"completion"` // cost per output (completion) token
	Request           string `json:"request"`
	Image             string `json:"image"`
	WebSearch         string `json:"web_search"`
	InternalReasoning string `json:"internal_reasoning"`
	InputCacheRead    string `json:"input_cache_read"`  // cost per cached token read
	InputCacheWrite   string `json:"input_cache_write"` // cost per cached token written
}
 55
// TopProvider describes the top provider's capabilities. It is used as a
// fallback source for context window and max-token limits when the
// per-endpoint lookup fails.
type TopProvider struct {
	ContextLength       int64  `json:"context_length"`
	MaxCompletionTokens *int64 `json:"max_completion_tokens"` // optional; nil when the provider reports no cap
	IsModerated         bool   `json:"is_moderated"`
}
 62
// Endpoint represents a single endpoint configuration for a model, i.e. one
// concrete provider serving that model.
type Endpoint struct {
	Name                string   `json:"name"`
	ContextLength       int64    `json:"context_length"`
	Pricing             Pricing  `json:"pricing"`
	ProviderName        string   `json:"provider_name"`
	Tag                 string   `json:"tag"`
	Quantization        *string  `json:"quantization"` // optional; nil when not reported
	MaxCompletionTokens *int64   `json:"max_completion_tokens"`
	MaxPromptTokens     *int64   `json:"max_prompt_tokens"`
	SupportedParams     []string `json:"supported_parameters"`
	Status              int      `json:"status"`          // negative values are treated as unhealthy by selectBestEndpoint
	UptimeLast30m       float64  `json:"uptime_last_30m"` // treated as a percentage (compared against 90.0)
}
 77
// EndpointsResponse is the response structure for the endpoints API
// (/api/v1/models/{id}/endpoints).
type EndpointsResponse struct {
	// Data wraps the model's metadata together with the list of
	// provider endpoints that serve it.
	Data struct {
		ID          string     `json:"id"`
		Name        string     `json:"name"`
		Created     int64      `json:"created"`
		Description string     `json:"description"`
		Endpoints   []Endpoint `json:"endpoints"`
	} `json:"data"`
}
 88
// ModelsResponse is the response structure for the models API
// (/api/v1/models): a flat list of all available models.
type ModelsResponse struct {
	Data []Model `json:"data"`
}
 93
// ModelPricing is the pricing structure for a model, detailing costs per
// million tokens for input and output, both cached and uncached.
// It is the numeric counterpart of the string-based Pricing type.
type ModelPricing struct {
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`  // derived from Pricing.InputCacheWrite
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` // derived from Pricing.InputCacheRead
}
102
103func getPricing(model Model) ModelPricing {
104	pricing := ModelPricing{}
105	costPrompt, err := strconv.ParseFloat(model.Pricing.Prompt, 64)
106	if err != nil {
107		costPrompt = 0.0
108	}
109	pricing.CostPer1MIn = costPrompt * 1_000_000
110	costCompletion, err := strconv.ParseFloat(model.Pricing.Completion, 64)
111	if err != nil {
112		costCompletion = 0.0
113	}
114	pricing.CostPer1MOut = costCompletion * 1_000_000
115
116	costPromptCached, err := strconv.ParseFloat(model.Pricing.InputCacheWrite, 64)
117	if err != nil {
118		costPromptCached = 0.0
119	}
120	pricing.CostPer1MInCached = costPromptCached * 1_000_000
121	costCompletionCached, err := strconv.ParseFloat(model.Pricing.InputCacheRead, 64)
122	if err != nil {
123		costCompletionCached = 0.0
124	}
125	pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
126	return pricing
127}
128
129func fetchOpenRouterModels() (*ModelsResponse, error) {
130	client := &http.Client{Timeout: 30 * time.Second}
131	req, _ := http.NewRequestWithContext(
132		context.Background(),
133		"GET",
134		"https://openrouter.ai/api/v1/models",
135		nil,
136	)
137	req.Header.Set("User-Agent", "Crush-Client/1.0")
138	resp, err := client.Do(req)
139	if err != nil {
140		return nil, err //nolint:wrapcheck
141	}
142	defer resp.Body.Close() //nolint:errcheck
143	if resp.StatusCode != 200 {
144		body, _ := io.ReadAll(resp.Body)
145		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
146	}
147	var mr ModelsResponse
148	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
149		return nil, err //nolint:wrapcheck
150	}
151	return &mr, nil
152}
153
154func fetchModelEndpoints(modelID string) (*EndpointsResponse, error) {
155	client := &http.Client{Timeout: 30 * time.Second}
156	url := fmt.Sprintf("https://openrouter.ai/api/v1/models/%s/endpoints", modelID)
157	req, _ := http.NewRequestWithContext(
158		context.Background(),
159		"GET",
160		url,
161		nil,
162	)
163	req.Header.Set("User-Agent", "Crush-Client/1.0")
164	resp, err := client.Do(req)
165	if err != nil {
166		return nil, err //nolint:wrapcheck
167	}
168	defer resp.Body.Close() //nolint:errcheck
169	if resp.StatusCode != 200 {
170		body, _ := io.ReadAll(resp.Body)
171		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
172	}
173	var er EndpointsResponse
174	if err := json.NewDecoder(resp.Body).Decode(&er); err != nil {
175		return nil, err //nolint:wrapcheck
176	}
177	return &er, nil
178}
179
180func selectBestEndpoint(endpoints []Endpoint) *Endpoint {
181	if len(endpoints) == 0 {
182		return nil
183	}
184
185	var best *Endpoint
186	for i := range endpoints {
187		endpoint := &endpoints[i]
188		// Skip endpoints with poor status or uptime
189		if endpoint.Status < 0 || endpoint.UptimeLast30m < 90.0 {
190			continue
191		}
192
193		if best == nil {
194			best = endpoint
195			continue
196		}
197
198		// Prefer higher context length
199		if endpoint.ContextLength > best.ContextLength {
200			best = endpoint
201		} else if endpoint.ContextLength == best.ContextLength {
202			// If context length is the same, prefer better uptime
203			if endpoint.UptimeLast30m > best.UptimeLast30m {
204				best = endpoint
205			}
206		}
207	}
208
209	// If no good endpoint found, return the first one as fallback
210	if best == nil {
211		best = &endpoints[0]
212	}
213
214	return best
215}
216
217// This is used to generate the openrouter.json config file.
218func main() {
219	modelsResp, err := fetchOpenRouterModels()
220	if err != nil {
221		log.Fatal("Error fetching OpenRouter models:", err)
222	}
223
224	openRouterProvider := catwalk.Provider{
225		Name:                "OpenRouter",
226		ID:                  "openrouter",
227		APIKey:              "$OPENROUTER_API_KEY",
228		APIEndpoint:         "https://openrouter.ai/api/v1",
229		Type:                catwalk.TypeOpenAI,
230		DefaultLargeModelID: "anthropic/claude-sonnet-4",
231		DefaultSmallModelID: "anthropic/claude-3.5-haiku",
232		Models:              []catwalk.Model{},
233	}
234
235	for _, model := range modelsResp.Data {
236		// skip non‐text models or those without tools
237		if !slices.Contains(model.SupportedParams, "tools") ||
238			!slices.Contains(model.Architecture.InputModalities, "text") ||
239			!slices.Contains(model.Architecture.OutputModalities, "text") {
240			continue
241		}
242
243		// Fetch endpoints for this model to get the best configuration
244		endpointsResp, err := fetchModelEndpoints(model.ID)
245		if err != nil {
246			fmt.Printf("Warning: Failed to fetch endpoints for %s: %v\n", model.ID, err)
247			// Fall back to using the original model data
248			pricing := getPricing(model)
249			canReason := slices.Contains(model.SupportedParams, "reasoning")
250			supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
251
252			m := catwalk.Model{
253				ID:                 model.ID,
254				Name:               model.Name,
255				CostPer1MIn:        pricing.CostPer1MIn,
256				CostPer1MOut:       pricing.CostPer1MOut,
257				CostPer1MInCached:  pricing.CostPer1MInCached,
258				CostPer1MOutCached: pricing.CostPer1MOutCached,
259				ContextWindow:      model.ContextLength,
260				CanReason:          canReason,
261				SupportsImages:     supportsImages,
262			}
263			if model.TopProvider.MaxCompletionTokens != nil {
264				m.DefaultMaxTokens = *model.TopProvider.MaxCompletionTokens / 2
265			} else {
266				m.DefaultMaxTokens = model.ContextLength / 10
267			}
268			if model.TopProvider.ContextLength > 0 {
269				m.ContextWindow = model.TopProvider.ContextLength
270			}
271			openRouterProvider.Models = append(openRouterProvider.Models, m)
272			continue
273		}
274
275		// Select the best endpoint
276		bestEndpoint := selectBestEndpoint(endpointsResp.Data.Endpoints)
277		if bestEndpoint == nil {
278			fmt.Printf("Warning: No suitable endpoint found for %s\n", model.ID)
279			continue
280		}
281
282		// Check if the best endpoint supports tools
283		if !slices.Contains(bestEndpoint.SupportedParams, "tools") {
284			continue
285		}
286
287		// Use the best endpoint's configuration
288		pricing := ModelPricing{}
289		costPrompt, err := strconv.ParseFloat(bestEndpoint.Pricing.Prompt, 64)
290		if err != nil {
291			costPrompt = 0.0
292		}
293		pricing.CostPer1MIn = costPrompt * 1_000_000
294		costCompletion, err := strconv.ParseFloat(bestEndpoint.Pricing.Completion, 64)
295		if err != nil {
296			costCompletion = 0.0
297		}
298		pricing.CostPer1MOut = costCompletion * 1_000_000
299
300		costPromptCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheWrite, 64)
301		if err != nil {
302			costPromptCached = 0.0
303		}
304		pricing.CostPer1MInCached = costPromptCached * 1_000_000
305		costCompletionCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheRead, 64)
306		if err != nil {
307			costCompletionCached = 0.0
308		}
309		pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
310
311		canReason := slices.Contains(bestEndpoint.SupportedParams, "reasoning")
312		supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
313
314		m := catwalk.Model{
315			ID:                 model.ID,
316			Name:               model.Name,
317			CostPer1MIn:        pricing.CostPer1MIn,
318			CostPer1MOut:       pricing.CostPer1MOut,
319			CostPer1MInCached:  pricing.CostPer1MInCached,
320			CostPer1MOutCached: pricing.CostPer1MOutCached,
321			ContextWindow:      bestEndpoint.ContextLength,
322			CanReason:          canReason,
323			SupportsImages:     supportsImages,
324		}
325
326		// Set max tokens based on the best endpoint
327		if bestEndpoint.MaxCompletionTokens != nil {
328			m.DefaultMaxTokens = *bestEndpoint.MaxCompletionTokens / 2
329		} else {
330			m.DefaultMaxTokens = bestEndpoint.ContextLength / 10
331		}
332
333		openRouterProvider.Models = append(openRouterProvider.Models, m)
334		fmt.Printf("Added model %s with context window %d from provider %s\n",
335			model.ID, bestEndpoint.ContextLength, bestEndpoint.ProviderName)
336	}
337
338	// save the json in internal/providers/config/openrouter.json
339	data, err := json.MarshalIndent(openRouterProvider, "", "  ")
340	if err != nil {
341		log.Fatal("Error marshaling OpenRouter provider:", err)
342	}
343	// write to file
344	if err := os.WriteFile("internal/providers/configs/openrouter.json", data, 0o600); err != nil {
345		log.Fatal("Error writing OpenRouter provider config:", err)
346	}
347}