// main.go

  1// Package main provides a command-line tool to fetch models from OpenRouter
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"net/http"
 12	"os"
 13	"slices"
 14	"strconv"
 15	"time"
 16
 17	"github.com/charmbracelet/catwalk/pkg/catwalk"
 18)
 19
// Model represents a single entry in the OpenRouter model catalog
// (GET /api/v1/models).
type Model struct {
	ID              string       `json:"id"`                   // provider-scoped identifier, e.g. "anthropic/claude-3.5-haiku"
	CanonicalSlug   string       `json:"canonical_slug"`       // stable slug for the model
	HuggingFaceID   string       `json:"hugging_face_id"`      // Hugging Face repo ID, when available
	Name            string       `json:"name"`                 // human-readable display name
	Created         int64        `json:"created"`              // creation timestamp (presumably Unix seconds — unused here)
	Description     string       `json:"description"`
	ContextLength   int64        `json:"context_length"`       // maximum context window, in tokens
	Architecture    Architecture `json:"architecture"`         // modality and tokenizer details
	Pricing         Pricing      `json:"pricing"`              // per-token costs as decimal strings (see getPricing)
	TopProvider     TopProvider  `json:"top_provider"`         // limits reported by the top-ranked provider
	SupportedParams []string     `json:"supported_parameters"` // feature flags, e.g. "tools", "reasoning"
}
 34
// Architecture defines the model's architecture details.
type Architecture struct {
	Modality         string   `json:"modality"`          // combined modality string reported by the API
	InputModalities  []string `json:"input_modalities"`  // accepted input kinds, e.g. "text", "image"
	OutputModalities []string `json:"output_modalities"` // produced output kinds, e.g. "text"
	Tokenizer        string   `json:"tokenizer"`         // tokenizer family name
	InstructType     *string  `json:"instruct_type"`     // instruction format; nil when not applicable
}
 43
// Pricing contains the pricing information for different operations.
// All values arrive as decimal strings; the token-based fields are
// per-token costs that getPricing scales to cost per 1M tokens.
type Pricing struct {
	Prompt            string `json:"prompt"`             // cost per input (prompt) token
	Completion        string `json:"completion"`         // cost per output (completion) token
	Request           string `json:"request"`            // presumably a flat per-request cost — unused here
	Image             string `json:"image"`              // presumably per-image cost — unused here
	WebSearch         string `json:"web_search"`         // unused here
	InternalReasoning string `json:"internal_reasoning"` // unused here
	InputCacheRead    string `json:"input_cache_read"`   // cost per cached-input token read
	InputCacheWrite   string `json:"input_cache_write"`  // cost per cached-input token write
}
 55
// TopProvider describes the top provider's capabilities.
type TopProvider struct {
	ContextLength       int64  `json:"context_length"`        // provider-specific context window; 0 when unreported
	MaxCompletionTokens *int64 `json:"max_completion_tokens"` // output-token limit; nil when unreported
	IsModerated         bool   `json:"is_moderated"`          // whether the provider moderates requests
}
 62
// Endpoint represents a single endpoint configuration for a model.
type Endpoint struct {
	Name                string   `json:"name"`
	ContextLength       int64    `json:"context_length"`        // context window offered by this endpoint, in tokens
	Pricing             Pricing  `json:"pricing"`               // endpoint-specific pricing
	ProviderName        string   `json:"provider_name"`         // upstream provider serving this endpoint
	Tag                 string   `json:"tag"`
	Quantization        *string  `json:"quantization"`          // quantization level; nil when unknown/unquantized
	MaxCompletionTokens *int64   `json:"max_completion_tokens"` // output-token limit; nil when unreported
	MaxPromptTokens     *int64   `json:"max_prompt_tokens"`     // input-token limit; nil when unreported
	SupportedParams     []string `json:"supported_parameters"`  // feature flags, e.g. "tools", "reasoning"
	Status              int      `json:"status"`                // negative values are treated as unhealthy (see selectBestEndpoint)
	UptimeLast30m       float64  `json:"uptime_last_30m"`       // recent uptime, compared against 90.0 as a percentage
}
 77
// EndpointsResponse is the response structure for the endpoints API
// (GET /api/v1/models/{id}/endpoints).
type EndpointsResponse struct {
	Data struct {
		ID          string     `json:"id"`
		Name        string     `json:"name"`
		Created     int64      `json:"created"`
		Description string     `json:"description"`
		Endpoints   []Endpoint `json:"endpoints"` // every endpoint currently serving this model
	} `json:"data"`
}
 88
// ModelsResponse is the response structure for the models API
// (GET /api/v1/models).
type ModelsResponse struct {
	Data []Model `json:"data"` // the full model catalog
}
 93
// ModelPricing is the pricing structure for a model, detailing costs per
// million tokens for input and output, both cached and uncached. Values are
// derived from the per-token string fields of Pricing (see getPricing).
type ModelPricing struct {
	CostPer1MIn        float64 `json:"cost_per_1m_in"`         // per 1M uncached input tokens
	CostPer1MOut       float64 `json:"cost_per_1m_out"`        // per 1M uncached output tokens
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`  // per 1M cache-write input tokens
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"` // per 1M cache-read input tokens
}
102
103func getPricing(model Model) ModelPricing {
104	pricing := ModelPricing{}
105	costPrompt, err := strconv.ParseFloat(model.Pricing.Prompt, 64)
106	if err != nil {
107		costPrompt = 0.0
108	}
109	pricing.CostPer1MIn = costPrompt * 1_000_000
110	costCompletion, err := strconv.ParseFloat(model.Pricing.Completion, 64)
111	if err != nil {
112		costCompletion = 0.0
113	}
114	pricing.CostPer1MOut = costCompletion * 1_000_000
115
116	costPromptCached, err := strconv.ParseFloat(model.Pricing.InputCacheWrite, 64)
117	if err != nil {
118		costPromptCached = 0.0
119	}
120	pricing.CostPer1MInCached = costPromptCached * 1_000_000
121	costCompletionCached, err := strconv.ParseFloat(model.Pricing.InputCacheRead, 64)
122	if err != nil {
123		costCompletionCached = 0.0
124	}
125	pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
126	return pricing
127}
128
129func fetchOpenRouterModels() (*ModelsResponse, error) {
130	client := &http.Client{Timeout: 30 * time.Second}
131	req, _ := http.NewRequestWithContext(
132		context.Background(),
133		"GET",
134		"https://openrouter.ai/api/v1/models",
135		nil,
136	)
137	req.Header.Set("User-Agent", "Crush-Client/1.0")
138	resp, err := client.Do(req)
139	if err != nil {
140		return nil, err //nolint:wrapcheck
141	}
142	defer resp.Body.Close() //nolint:errcheck
143	if resp.StatusCode != 200 {
144		body, _ := io.ReadAll(resp.Body)
145		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
146	}
147	var mr ModelsResponse
148	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
149		return nil, err //nolint:wrapcheck
150	}
151	return &mr, nil
152}
153
154func fetchModelEndpoints(modelID string) (*EndpointsResponse, error) {
155	client := &http.Client{Timeout: 30 * time.Second}
156	url := fmt.Sprintf("https://openrouter.ai/api/v1/models/%s/endpoints", modelID)
157	req, _ := http.NewRequestWithContext(
158		context.Background(),
159		"GET",
160		url,
161		nil,
162	)
163	req.Header.Set("User-Agent", "Crush-Client/1.0")
164	resp, err := client.Do(req)
165	if err != nil {
166		return nil, err //nolint:wrapcheck
167	}
168	defer resp.Body.Close() //nolint:errcheck
169	if resp.StatusCode != 200 {
170		body, _ := io.ReadAll(resp.Body)
171		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
172	}
173	var er EndpointsResponse
174	if err := json.NewDecoder(resp.Body).Decode(&er); err != nil {
175		return nil, err //nolint:wrapcheck
176	}
177	return &er, nil
178}
179
180func selectBestEndpoint(endpoints []Endpoint) *Endpoint {
181	if len(endpoints) == 0 {
182		return nil
183	}
184
185	var best *Endpoint
186	for i := range endpoints {
187		endpoint := &endpoints[i]
188		// Skip endpoints with poor status or uptime
189		if endpoint.Status < 0 || endpoint.UptimeLast30m < 90.0 {
190			continue
191		}
192
193		if best == nil {
194			best = endpoint
195			continue
196		}
197
198		if isBetterEndpoint(endpoint, best) {
199			best = endpoint
200		}
201	}
202
203	// If no good endpoint found, return the first one as fallback
204	if best == nil {
205		best = &endpoints[0]
206	}
207
208	return best
209}
210
211func isBetterEndpoint(candidate, current *Endpoint) bool {
212	candidateHasTools := slices.Contains(candidate.SupportedParams, "tools")
213	currentHasTools := slices.Contains(current.SupportedParams, "tools")
214
215	// Prefer endpoints with tool support over those without
216	if candidateHasTools && !currentHasTools {
217		return true
218	}
219	if !candidateHasTools && currentHasTools {
220		return false
221	}
222
223	// Both have same tool support status, compare other factors
224	if candidate.ContextLength > current.ContextLength {
225		return true
226	}
227	if candidate.ContextLength == current.ContextLength {
228		return candidate.UptimeLast30m > current.UptimeLast30m
229	}
230
231	return false
232}
233
234// This is used to generate the openrouter.json config file.
235func main() {
236	modelsResp, err := fetchOpenRouterModels()
237	if err != nil {
238		log.Fatal("Error fetching OpenRouter models:", err)
239	}
240
241	openRouterProvider := catwalk.Provider{
242		Name:                "OpenRouter",
243		ID:                  "openrouter",
244		APIKey:              "$OPENROUTER_API_KEY",
245		APIEndpoint:         "https://openrouter.ai/api/v1",
246		Type:                catwalk.TypeOpenAI,
247		DefaultLargeModelID: "anthropic/claude-sonnet-4",
248		DefaultSmallModelID: "anthropic/claude-3.5-haiku",
249		Models:              []catwalk.Model{},
250		DefaultHeaders: map[string]string{
251			"HTTP-Referer": "https://charm.land",
252			"X-Title":      "Crush",
253		},
254	}
255
256	for _, model := range modelsResp.Data {
257		// skip non‐text models or those without tools
258		if !slices.Contains(model.SupportedParams, "tools") ||
259			!slices.Contains(model.Architecture.InputModalities, "text") ||
260			!slices.Contains(model.Architecture.OutputModalities, "text") {
261			continue
262		}
263
264		// Fetch endpoints for this model to get the best configuration
265		endpointsResp, err := fetchModelEndpoints(model.ID)
266		if err != nil {
267			fmt.Printf("Warning: Failed to fetch endpoints for %s: %v\n", model.ID, err)
268			// Fall back to using the original model data
269			pricing := getPricing(model)
270			canReason := slices.Contains(model.SupportedParams, "reasoning")
271			supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
272
273			m := catwalk.Model{
274				ID:                 model.ID,
275				Name:               model.Name,
276				CostPer1MIn:        pricing.CostPer1MIn,
277				CostPer1MOut:       pricing.CostPer1MOut,
278				CostPer1MInCached:  pricing.CostPer1MInCached,
279				CostPer1MOutCached: pricing.CostPer1MOutCached,
280				ContextWindow:      model.ContextLength,
281				CanReason:          canReason,
282				HasReasoningEffort: canReason,
283				SupportsImages:     supportsImages,
284			}
285			if model.TopProvider.MaxCompletionTokens != nil {
286				m.DefaultMaxTokens = *model.TopProvider.MaxCompletionTokens / 2
287			} else {
288				m.DefaultMaxTokens = model.ContextLength / 10
289			}
290			if model.TopProvider.ContextLength > 0 {
291				m.ContextWindow = model.TopProvider.ContextLength
292			}
293			openRouterProvider.Models = append(openRouterProvider.Models, m)
294			continue
295		}
296
297		// Select the best endpoint
298		bestEndpoint := selectBestEndpoint(endpointsResp.Data.Endpoints)
299		if bestEndpoint == nil {
300			fmt.Printf("Warning: No suitable endpoint found for %s\n", model.ID)
301			continue
302		}
303
304		// Check if the best endpoint supports tools
305		if !slices.Contains(bestEndpoint.SupportedParams, "tools") {
306			continue
307		}
308
309		// Use the best endpoint's configuration
310		pricing := ModelPricing{}
311		costPrompt, err := strconv.ParseFloat(bestEndpoint.Pricing.Prompt, 64)
312		if err != nil {
313			costPrompt = 0.0
314		}
315		pricing.CostPer1MIn = costPrompt * 1_000_000
316		costCompletion, err := strconv.ParseFloat(bestEndpoint.Pricing.Completion, 64)
317		if err != nil {
318			costCompletion = 0.0
319		}
320		pricing.CostPer1MOut = costCompletion * 1_000_000
321
322		costPromptCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheWrite, 64)
323		if err != nil {
324			costPromptCached = 0.0
325		}
326		pricing.CostPer1MInCached = costPromptCached * 1_000_000
327		costCompletionCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheRead, 64)
328		if err != nil {
329			costCompletionCached = 0.0
330		}
331		pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
332
333		canReason := slices.Contains(bestEndpoint.SupportedParams, "reasoning")
334		supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
335
336		m := catwalk.Model{
337			ID:                 model.ID,
338			Name:               model.Name,
339			CostPer1MIn:        pricing.CostPer1MIn,
340			CostPer1MOut:       pricing.CostPer1MOut,
341			CostPer1MInCached:  pricing.CostPer1MInCached,
342			CostPer1MOutCached: pricing.CostPer1MOutCached,
343			ContextWindow:      bestEndpoint.ContextLength,
344			CanReason:          canReason,
345			HasReasoningEffort: canReason,
346			SupportsImages:     supportsImages,
347		}
348
349		// Set max tokens based on the best endpoint
350		if bestEndpoint.MaxCompletionTokens != nil {
351			m.DefaultMaxTokens = *bestEndpoint.MaxCompletionTokens / 2
352		} else {
353			m.DefaultMaxTokens = bestEndpoint.ContextLength / 10
354		}
355
356		openRouterProvider.Models = append(openRouterProvider.Models, m)
357		fmt.Printf("Added model %s with context window %d from provider %s\n",
358			model.ID, bestEndpoint.ContextLength, bestEndpoint.ProviderName)
359	}
360
361	// save the json in internal/providers/config/openrouter.json
362	data, err := json.MarshalIndent(openRouterProvider, "", "  ")
363	if err != nil {
364		log.Fatal("Error marshaling OpenRouter provider:", err)
365	}
366	// write to file
367	if err := os.WriteFile("internal/providers/configs/openrouter.json", data, 0o600); err != nil {
368		log.Fatal("Error writing OpenRouter provider config:", err)
369	}
370}