main.go

  1// Package main provides a command-line tool to fetch models from OpenRouter
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"net/http"
 12	"os"
 13	"slices"
 14	"strconv"
 15	"time"
 16
 17	"github.com/charmbracelet/catwalk/pkg/catwalk"
 18)
 19
 // Model is one entry of the "data" array returned by OpenRouter's
// GET /api/v1/models endpoint (decoded via ModelsResponse in
// fetchOpenRouterModels). JSON keys without a matching field are ignored by
// encoding/json.
type Model struct {
	// ID is copied into catwalk.Model.ID and interpolated into the
	// per-model endpoints URL by fetchModelEndpoints.
	ID string `json:"id"`
	// CanonicalSlug is decoded but not read by this program.
	CanonicalSlug string `json:"canonical_slug"`
	// HuggingFaceID is decoded but not read by this program.
	HuggingFaceID string `json:"hugging_face_id"`
	// Name is copied into catwalk.Model.Name.
	Name string `json:"name"`
	// Created is decoded but not read by this program.
	Created int64 `json:"created"`
	// Description is decoded but not read by this program.
	Description string `json:"description"`
	// ContextLength is the listing-level context length; it is only used
	// when the per-model endpoints lookup fails (fallback path in main).
	ContextLength int64 `json:"context_length"`
	// Architecture supplies the input/output modalities main filters on.
	Architecture Architecture `json:"architecture"`
	// Pricing holds the listing-level price strings, converted by getPricing.
	Pricing Pricing `json:"pricing"`
	// TopProvider supplies context/token limits for the fallback path.
	TopProvider TopProvider `json:"top_provider"`
	// SupportedParams must contain "tools" for the model to be considered;
	// "reasoning" sets CanReason on the fallback path.
	SupportedParams []string `json:"supported_parameters"`
}
 34
 // Architecture is the "architecture" object of a Model listing entry.
type Architecture struct {
	// Modality is decoded but not read by this program.
	Modality string `json:"modality"`
	// InputModalities must contain "text" for the model to be kept;
	// containing "image" sets catwalk.Model.SupportsImages.
	InputModalities []string `json:"input_modalities"`
	// OutputModalities must contain "text" for the model to be kept.
	OutputModalities []string `json:"output_modalities"`
	// Tokenizer is decoded but not read by this program.
	Tokenizer string `json:"tokenizer"`
	// InstructType is a pointer so a JSON null decodes to nil; it is not
	// read by this program.
	InstructType *string `json:"instruct_type"`
}
 43
 // Pricing holds OpenRouter price fields exactly as the decimal strings found
// in the JSON. It appears both on Model (listing) and on Endpoint.
//
// getPricing parses these with strconv.ParseFloat and scales them by
// 1,000,000 into ModelPricing; empty or unparsable strings count as 0.
// NOTE(review): the scaling suggests per-token prices (currency not visible
// here); confirm against the OpenRouter API docs.
type Pricing struct {
	// Prompt becomes CostPer1MIn.
	Prompt string `json:"prompt"`
	// Completion becomes CostPer1MOut.
	Completion string `json:"completion"`
	// Request is decoded but not read by this program.
	Request string `json:"request"`
	// Image is decoded but not read by this program.
	Image string `json:"image"`
	// WebSearch is decoded but not read by this program.
	WebSearch string `json:"web_search"`
	// InternalReasoning is decoded but not read by this program.
	InternalReasoning string `json:"internal_reasoning"`
	// InputCacheRead becomes CostPer1MOutCached.
	InputCacheRead string `json:"input_cache_read"`
	// InputCacheWrite becomes CostPer1MInCached.
	InputCacheWrite string `json:"input_cache_write"`
}
 55
 // TopProvider is the "top_provider" object of a Model listing entry. Its
// values are only used when main falls back to listing data.
type TopProvider struct {
	// ContextLength, when > 0, overrides the listing's ContextLength as the
	// fallback model's context window.
	ContextLength int64 `json:"context_length"`
	// MaxCompletionTokens is nil when absent or null in the JSON; when set,
	// half of it becomes the fallback model's DefaultMaxTokens.
	MaxCompletionTokens *int64 `json:"max_completion_tokens"`
	// IsModerated is decoded but not read by this program.
	IsModerated bool `json:"is_moderated"`
}
 62
 // Endpoint is one provider-specific endpoint of a model, as returned in
// EndpointsResponse.Data.Endpoints. selectBestEndpoint chooses one of these
// per model, and main builds the catwalk.Model from it.
type Endpoint struct {
	// Name is decoded but not read by this program.
	Name string `json:"name"`
	// ContextLength becomes the model's context window; one tenth of it is
	// the DefaultMaxTokens fallback when MaxCompletionTokens is nil. It is
	// also a ranking criterion in isBetterEndpoint.
	ContextLength int64 `json:"context_length"`
	// Pricing is converted the same way as the listing-level Model.Pricing.
	Pricing Pricing `json:"pricing"`
	// ProviderName is only printed in main's progress output.
	ProviderName string `json:"provider_name"`
	// Tag is decoded but not read by this program.
	Tag string `json:"tag"`
	// Quantization is nil when absent or null; not read by this program.
	Quantization *string `json:"quantization"`
	// MaxCompletionTokens is nil when absent or null; when set, half of it
	// becomes DefaultMaxTokens.
	MaxCompletionTokens *int64 `json:"max_completion_tokens"`
	// MaxPromptTokens is decoded but not read by this program.
	MaxPromptTokens *int64 `json:"max_prompt_tokens"`
	// SupportedParams: "tools" is preferred when ranking and required for
	// the model to be emitted; "reasoning" sets CanReason.
	SupportedParams []string `json:"supported_parameters"`
	// Status: selectBestEndpoint treats negative values as unhealthy.
	Status int `json:"status"`
	// UptimeLast30m: selectBestEndpoint treats values below 90.0 as
	// unhealthy (presumably a percentage — confirm). A missing or null value
	// leaves it at 0, which is therefore also treated as unhealthy.
	UptimeLast30m float64 `json:"uptime_last_30m"`
}
 77
 // EndpointsResponse is the body of GET /api/v1/models/{id}/endpoints, as
// decoded by fetchModelEndpoints. main only reads Data.Endpoints.
type EndpointsResponse struct {
	Data struct {
		// ID, Name, Created and Description describe the model and are
		// decoded but not read by this program.
		ID          string     `json:"id"`
		Name        string     `json:"name"`
		Created     int64      `json:"created"`
		Description string     `json:"description"`
		Endpoints   []Endpoint `json:"endpoints"`
	} `json:"data"`
}
 88
 // ModelsResponse is the body of GET /api/v1/models, as decoded by
// fetchOpenRouterModels: a single "data" array of model listing entries.
type ModelsResponse struct {
	Data []Model `json:"data"`
}
 93
 // ModelPricing is the pricing structure for a model, detailing costs per
// million tokens. It is produced by getPricing and its fields are copied
// one-to-one into the matching catwalk.Model cost fields.
//
// Despite the In/Out naming of the cached fields, getPricing fills
// CostPer1MInCached from Pricing.InputCacheWrite and CostPer1MOutCached from
// Pricing.InputCacheRead.
type ModelPricing struct {
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
}
102
103func getPricing(model Model) ModelPricing {
104	pricing := ModelPricing{}
105	costPrompt, err := strconv.ParseFloat(model.Pricing.Prompt, 64)
106	if err != nil {
107		costPrompt = 0.0
108	}
109	pricing.CostPer1MIn = costPrompt * 1_000_000
110	costCompletion, err := strconv.ParseFloat(model.Pricing.Completion, 64)
111	if err != nil {
112		costCompletion = 0.0
113	}
114	pricing.CostPer1MOut = costCompletion * 1_000_000
115
116	costPromptCached, err := strconv.ParseFloat(model.Pricing.InputCacheWrite, 64)
117	if err != nil {
118		costPromptCached = 0.0
119	}
120	pricing.CostPer1MInCached = costPromptCached * 1_000_000
121	costCompletionCached, err := strconv.ParseFloat(model.Pricing.InputCacheRead, 64)
122	if err != nil {
123		costCompletionCached = 0.0
124	}
125	pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
126	return pricing
127}
128
129func fetchOpenRouterModels() (*ModelsResponse, error) {
130	client := &http.Client{Timeout: 30 * time.Second}
131	req, _ := http.NewRequestWithContext(
132		context.Background(),
133		"GET",
134		"https://openrouter.ai/api/v1/models",
135		nil,
136	)
137	req.Header.Set("User-Agent", "Crush-Client/1.0")
138	resp, err := client.Do(req)
139	if err != nil {
140		return nil, err //nolint:wrapcheck
141	}
142	defer resp.Body.Close() //nolint:errcheck
143	if resp.StatusCode != 200 {
144		body, _ := io.ReadAll(resp.Body)
145		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
146	}
147	var mr ModelsResponse
148	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
149		return nil, err //nolint:wrapcheck
150	}
151	return &mr, nil
152}
153
154func fetchModelEndpoints(modelID string) (*EndpointsResponse, error) {
155	client := &http.Client{Timeout: 30 * time.Second}
156	url := fmt.Sprintf("https://openrouter.ai/api/v1/models/%s/endpoints", modelID)
157	req, _ := http.NewRequestWithContext(
158		context.Background(),
159		"GET",
160		url,
161		nil,
162	)
163	req.Header.Set("User-Agent", "Crush-Client/1.0")
164	resp, err := client.Do(req)
165	if err != nil {
166		return nil, err //nolint:wrapcheck
167	}
168	defer resp.Body.Close() //nolint:errcheck
169	if resp.StatusCode != 200 {
170		body, _ := io.ReadAll(resp.Body)
171		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
172	}
173	var er EndpointsResponse
174	if err := json.NewDecoder(resp.Body).Decode(&er); err != nil {
175		return nil, err //nolint:wrapcheck
176	}
177	return &er, nil
178}
179
180func selectBestEndpoint(endpoints []Endpoint) *Endpoint {
181	if len(endpoints) == 0 {
182		return nil
183	}
184
185	var best *Endpoint
186	for i := range endpoints {
187		endpoint := &endpoints[i]
188		// Skip endpoints with poor status or uptime
189		if endpoint.Status < 0 || endpoint.UptimeLast30m < 90.0 {
190			continue
191		}
192
193		if best == nil {
194			best = endpoint
195			continue
196		}
197
198		if isBetterEndpoint(endpoint, best) {
199			best = endpoint
200		}
201	}
202
203	// If no good endpoint found, return the first one as fallback
204	if best == nil {
205		best = &endpoints[0]
206	}
207
208	return best
209}
210
211func isBetterEndpoint(candidate, current *Endpoint) bool {
212	candidateHasTools := slices.Contains(candidate.SupportedParams, "tools")
213	currentHasTools := slices.Contains(current.SupportedParams, "tools")
214
215	// Prefer endpoints with tool support over those without
216	if candidateHasTools && !currentHasTools {
217		return true
218	}
219	if !candidateHasTools && currentHasTools {
220		return false
221	}
222
223	// Both have same tool support status, compare other factors
224	if candidate.ContextLength > current.ContextLength {
225		return true
226	}
227	if candidate.ContextLength == current.ContextLength {
228		return candidate.UptimeLast30m > current.UptimeLast30m
229	}
230
231	return false
232}
233
234// This is used to generate the openrouter.json config file.
235func main() {
236	modelsResp, err := fetchOpenRouterModels()
237	if err != nil {
238		log.Fatal("Error fetching OpenRouter models:", err)
239	}
240
241	openRouterProvider := catwalk.Provider{
242		Name:                "OpenRouter",
243		ID:                  "openrouter",
244		APIKey:              "$OPENROUTER_API_KEY",
245		APIEndpoint:         "https://openrouter.ai/api/v1",
246		Type:                catwalk.TypeOpenAI,
247		DefaultLargeModelID: "anthropic/claude-sonnet-4",
248		DefaultSmallModelID: "anthropic/claude-3.5-haiku",
249		Models:              []catwalk.Model{},
250	}
251
252	for _, model := range modelsResp.Data {
253		// skip non‐text models or those without tools
254		if !slices.Contains(model.SupportedParams, "tools") ||
255			!slices.Contains(model.Architecture.InputModalities, "text") ||
256			!slices.Contains(model.Architecture.OutputModalities, "text") {
257			continue
258		}
259
260		// Fetch endpoints for this model to get the best configuration
261		endpointsResp, err := fetchModelEndpoints(model.ID)
262		if err != nil {
263			fmt.Printf("Warning: Failed to fetch endpoints for %s: %v\n", model.ID, err)
264			// Fall back to using the original model data
265			pricing := getPricing(model)
266			canReason := slices.Contains(model.SupportedParams, "reasoning")
267			supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
268
269			m := catwalk.Model{
270				ID:                 model.ID,
271				Name:               model.Name,
272				CostPer1MIn:        pricing.CostPer1MIn,
273				CostPer1MOut:       pricing.CostPer1MOut,
274				CostPer1MInCached:  pricing.CostPer1MInCached,
275				CostPer1MOutCached: pricing.CostPer1MOutCached,
276				ContextWindow:      model.ContextLength,
277				CanReason:          canReason,
278				SupportsImages:     supportsImages,
279			}
280			if model.TopProvider.MaxCompletionTokens != nil {
281				m.DefaultMaxTokens = *model.TopProvider.MaxCompletionTokens / 2
282			} else {
283				m.DefaultMaxTokens = model.ContextLength / 10
284			}
285			if model.TopProvider.ContextLength > 0 {
286				m.ContextWindow = model.TopProvider.ContextLength
287			}
288			openRouterProvider.Models = append(openRouterProvider.Models, m)
289			continue
290		}
291
292		// Select the best endpoint
293		bestEndpoint := selectBestEndpoint(endpointsResp.Data.Endpoints)
294		if bestEndpoint == nil {
295			fmt.Printf("Warning: No suitable endpoint found for %s\n", model.ID)
296			continue
297		}
298
299		// Check if the best endpoint supports tools
300		if !slices.Contains(bestEndpoint.SupportedParams, "tools") {
301			continue
302		}
303
304		// Use the best endpoint's configuration
305		pricing := ModelPricing{}
306		costPrompt, err := strconv.ParseFloat(bestEndpoint.Pricing.Prompt, 64)
307		if err != nil {
308			costPrompt = 0.0
309		}
310		pricing.CostPer1MIn = costPrompt * 1_000_000
311		costCompletion, err := strconv.ParseFloat(bestEndpoint.Pricing.Completion, 64)
312		if err != nil {
313			costCompletion = 0.0
314		}
315		pricing.CostPer1MOut = costCompletion * 1_000_000
316
317		costPromptCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheWrite, 64)
318		if err != nil {
319			costPromptCached = 0.0
320		}
321		pricing.CostPer1MInCached = costPromptCached * 1_000_000
322		costCompletionCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheRead, 64)
323		if err != nil {
324			costCompletionCached = 0.0
325		}
326		pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
327
328		canReason := slices.Contains(bestEndpoint.SupportedParams, "reasoning")
329		supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
330
331		m := catwalk.Model{
332			ID:                 model.ID,
333			Name:               model.Name,
334			CostPer1MIn:        pricing.CostPer1MIn,
335			CostPer1MOut:       pricing.CostPer1MOut,
336			CostPer1MInCached:  pricing.CostPer1MInCached,
337			CostPer1MOutCached: pricing.CostPer1MOutCached,
338			ContextWindow:      bestEndpoint.ContextLength,
339			CanReason:          canReason,
340			SupportsImages:     supportsImages,
341		}
342
343		// Set max tokens based on the best endpoint
344		if bestEndpoint.MaxCompletionTokens != nil {
345			m.DefaultMaxTokens = *bestEndpoint.MaxCompletionTokens / 2
346		} else {
347			m.DefaultMaxTokens = bestEndpoint.ContextLength / 10
348		}
349
350		openRouterProvider.Models = append(openRouterProvider.Models, m)
351		fmt.Printf("Added model %s with context window %d from provider %s\n",
352			model.ID, bestEndpoint.ContextLength, bestEndpoint.ProviderName)
353	}
354
355	// save the json in internal/providers/config/openrouter.json
356	data, err := json.MarshalIndent(openRouterProvider, "", "  ")
357	if err != nil {
358		log.Fatal("Error marshaling OpenRouter provider:", err)
359	}
360	// write to file
361	if err := os.WriteFile("internal/providers/configs/openrouter.json", data, 0o600); err != nil {
362		log.Fatal("Error writing OpenRouter provider config:", err)
363	}
364}