main.go

  1// Package main provides a command-line tool to fetch models from OpenRouter
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"net/http"
 12	"os"
 13	"slices"
 14	"strconv"
 15	"strings"
 16	"time"
 17
 18	"github.com/charmbracelet/catwalk/pkg/catwalk"
 19)
 20
// Model represents a single model entry as returned by the OpenRouter
// /api/v1/models endpoint.
type Model struct {
	ID              string       `json:"id"`
	CanonicalSlug   string       `json:"canonical_slug"`
	HuggingFaceID   string       `json:"hugging_face_id"`
	Name            string       `json:"name"`
	Created         int64        `json:"created"` // creation timestamp; not used by this tool
	Description     string       `json:"description"`
	ContextLength   int64        `json:"context_length"` // aggregate context window in tokens
	Architecture    Architecture `json:"architecture"`
	Pricing         Pricing      `json:"pricing"`
	TopProvider     TopProvider  `json:"top_provider"`
	SupportedParams []string     `json:"supported_parameters"` // e.g. "tools", "reasoning"
}
 35
// Architecture defines the model's architecture details, including which
// input/output modalities (e.g. "text", "image") the model supports.
type Architecture struct {
	Modality         string   `json:"modality"`
	InputModalities  []string `json:"input_modalities"`  // checked for "text" and "image" below
	OutputModalities []string `json:"output_modalities"` // checked for "text" below
	Tokenizer        string   `json:"tokenizer"`
	InstructType     *string  `json:"instruct_type"` // nil when the API reports null
}
 44
// Pricing contains the pricing information for different operations.
// All values are decimal strings as returned by the API (parsed with
// strconv.ParseFloat; presumably USD per token/request — see getPricing).
type Pricing struct {
	Prompt            string `json:"prompt"`
	Completion        string `json:"completion"`
	Request           string `json:"request"`
	Image             string `json:"image"`
	WebSearch         string `json:"web_search"`
	InternalReasoning string `json:"internal_reasoning"`
	InputCacheRead    string `json:"input_cache_read"`
	InputCacheWrite   string `json:"input_cache_write"`
}
 56
// TopProvider describes the top provider's capabilities for a model.
type TopProvider struct {
	ContextLength       int64  `json:"context_length"`
	MaxCompletionTokens *int64 `json:"max_completion_tokens"` // nil when the API reports null
	IsModerated         bool   `json:"is_moderated"`
}
 63
// Endpoint represents a single endpoint configuration for a model, i.e. one
// concrete provider serving it.
type Endpoint struct {
	Name                string   `json:"name"`
	ContextLength       int64    `json:"context_length"`
	Pricing             Pricing  `json:"pricing"`
	ProviderName        string   `json:"provider_name"`
	Tag                 string   `json:"tag"`
	Quantization        *string  `json:"quantization"` // nil when the API reports null
	MaxCompletionTokens *int64   `json:"max_completion_tokens"`
	MaxPromptTokens     *int64   `json:"max_prompt_tokens"`
	SupportedParams     []string `json:"supported_parameters"`
	Status              int      `json:"status"`          // negative values are treated as unhealthy (see selectBestEndpoint)
	UptimeLast30m       float64  `json:"uptime_last_30m"` // compared against 90.0, so presumably a percentage
}
 78
// EndpointsResponse is the response structure for the
// /api/v1/models/{id}/endpoints API.
type EndpointsResponse struct {
	Data struct {
		ID          string     `json:"id"`
		Name        string     `json:"name"`
		Created     int64      `json:"created"`
		Description string     `json:"description"`
		Endpoints   []Endpoint `json:"endpoints"`
	} `json:"data"`
}
 89
// ModelsResponse is the response structure for the /api/v1/models API.
type ModelsResponse struct {
	Data []Model `json:"data"`
}
 94
// ModelPricing is the pricing structure for a model, detailing costs per
// million tokens for input and output, both cached and uncached.
type ModelPricing struct {
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
}
103
104func getPricing(model Model) ModelPricing {
105	pricing := ModelPricing{}
106	costPrompt, err := strconv.ParseFloat(model.Pricing.Prompt, 64)
107	if err != nil {
108		costPrompt = 0.0
109	}
110	pricing.CostPer1MIn = costPrompt * 1_000_000
111	costCompletion, err := strconv.ParseFloat(model.Pricing.Completion, 64)
112	if err != nil {
113		costCompletion = 0.0
114	}
115	pricing.CostPer1MOut = costCompletion * 1_000_000
116
117	costPromptCached, err := strconv.ParseFloat(model.Pricing.InputCacheWrite, 64)
118	if err != nil {
119		costPromptCached = 0.0
120	}
121	pricing.CostPer1MInCached = costPromptCached * 1_000_000
122	costCompletionCached, err := strconv.ParseFloat(model.Pricing.InputCacheRead, 64)
123	if err != nil {
124		costCompletionCached = 0.0
125	}
126	pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
127	return pricing
128}
129
130func fetchOpenRouterModels() (*ModelsResponse, error) {
131	client := &http.Client{Timeout: 30 * time.Second}
132	req, _ := http.NewRequestWithContext(
133		context.Background(),
134		"GET",
135		"https://openrouter.ai/api/v1/models",
136		nil,
137	)
138	req.Header.Set("User-Agent", "Crush-Client/1.0")
139	resp, err := client.Do(req)
140	if err != nil {
141		return nil, err //nolint:wrapcheck
142	}
143	defer resp.Body.Close() //nolint:errcheck
144	if resp.StatusCode != 200 {
145		body, _ := io.ReadAll(resp.Body)
146		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
147	}
148	var mr ModelsResponse
149	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
150		return nil, err //nolint:wrapcheck
151	}
152	return &mr, nil
153}
154
155func fetchModelEndpoints(modelID string) (*EndpointsResponse, error) {
156	client := &http.Client{Timeout: 30 * time.Second}
157	url := fmt.Sprintf("https://openrouter.ai/api/v1/models/%s/endpoints", modelID)
158	req, _ := http.NewRequestWithContext(
159		context.Background(),
160		"GET",
161		url,
162		nil,
163	)
164	req.Header.Set("User-Agent", "Crush-Client/1.0")
165	resp, err := client.Do(req)
166	if err != nil {
167		return nil, err //nolint:wrapcheck
168	}
169	defer resp.Body.Close() //nolint:errcheck
170	if resp.StatusCode != 200 {
171		body, _ := io.ReadAll(resp.Body)
172		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
173	}
174	var er EndpointsResponse
175	if err := json.NewDecoder(resp.Body).Decode(&er); err != nil {
176		return nil, err //nolint:wrapcheck
177	}
178	return &er, nil
179}
180
181func selectBestEndpoint(endpoints []Endpoint) *Endpoint {
182	if len(endpoints) == 0 {
183		return nil
184	}
185
186	var best *Endpoint
187	for i := range endpoints {
188		endpoint := &endpoints[i]
189		// Skip endpoints with poor status or uptime
190		if endpoint.Status < 0 || endpoint.UptimeLast30m < 90.0 {
191			continue
192		}
193
194		if best == nil {
195			best = endpoint
196			continue
197		}
198
199		if isBetterEndpoint(endpoint, best) {
200			best = endpoint
201		}
202	}
203
204	// If no good endpoint found, return the first one as fallback
205	if best == nil {
206		best = &endpoints[0]
207	}
208
209	return best
210}
211
212func isBetterEndpoint(candidate, current *Endpoint) bool {
213	candidateHasTools := slices.Contains(candidate.SupportedParams, "tools")
214	currentHasTools := slices.Contains(current.SupportedParams, "tools")
215
216	// Prefer endpoints with tool support over those without
217	if candidateHasTools && !currentHasTools {
218		return true
219	}
220	if !candidateHasTools && currentHasTools {
221		return false
222	}
223
224	// Both have same tool support status, compare other factors
225	if candidate.ContextLength > current.ContextLength {
226		return true
227	}
228	if candidate.ContextLength == current.ContextLength {
229		return candidate.UptimeLast30m > current.UptimeLast30m
230	}
231
232	return false
233}
234
235// This is used to generate the openrouter.json config file.
236func main() {
237	modelsResp, err := fetchOpenRouterModels()
238	if err != nil {
239		log.Fatal("Error fetching OpenRouter models:", err)
240	}
241
242	openRouterProvider := catwalk.Provider{
243		Name:                "OpenRouter",
244		ID:                  "openrouter",
245		APIKey:              "$OPENROUTER_API_KEY",
246		APIEndpoint:         "https://openrouter.ai/api/v1",
247		Type:                catwalk.TypeOpenRouter,
248		DefaultLargeModelID: "anthropic/claude-sonnet-4",
249		DefaultSmallModelID: "anthropic/claude-3.5-haiku",
250		Models:              []catwalk.Model{},
251		DefaultHeaders: map[string]string{
252			"HTTP-Referer": "https://charm.land",
253			"X-Title":      "Crush",
254		},
255	}
256
257	for _, model := range modelsResp.Data {
258		if model.ContextLength < 20000 {
259			continue
260		}
261		// skip non‐text models or those without tools
262		if !slices.Contains(model.SupportedParams, "tools") ||
263			!slices.Contains(model.Architecture.InputModalities, "text") ||
264			!slices.Contains(model.Architecture.OutputModalities, "text") {
265			continue
266		}
267
268		// Fetch endpoints for this model to get the best configuration
269		endpointsResp, err := fetchModelEndpoints(model.ID)
270		if err != nil {
271			fmt.Printf("Warning: Failed to fetch endpoints for %s: %v\n", model.ID, err)
272			// Fall back to using the original model data
273			pricing := getPricing(model)
274			canReason := slices.Contains(model.SupportedParams, "reasoning")
275			supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
276
277			var reasoningLevels []string
278			var defaultReasoning string
279			if canReason {
280				reasoningLevels = []string{"low", "medium", "high"}
281				defaultReasoning = "medium"
282			}
283			m := catwalk.Model{
284				ID:                     model.ID,
285				Name:                   model.Name,
286				CostPer1MIn:            pricing.CostPer1MIn,
287				CostPer1MOut:           pricing.CostPer1MOut,
288				CostPer1MInCached:      pricing.CostPer1MInCached,
289				CostPer1MOutCached:     pricing.CostPer1MOutCached,
290				ContextWindow:          model.ContextLength,
291				CanReason:              canReason,
292				DefaultReasoningEffort: defaultReasoning,
293				ReasoningLevels:        reasoningLevels,
294				SupportsImages:         supportsImages,
295			}
296			if model.TopProvider.MaxCompletionTokens != nil {
297				m.DefaultMaxTokens = *model.TopProvider.MaxCompletionTokens / 2
298			} else {
299				m.DefaultMaxTokens = model.ContextLength / 10
300			}
301			if model.TopProvider.ContextLength > 0 {
302				m.ContextWindow = model.TopProvider.ContextLength
303			}
304			openRouterProvider.Models = append(openRouterProvider.Models, m)
305			continue
306		}
307
308		// Select the best endpoint
309		bestEndpoint := selectBestEndpoint(endpointsResp.Data.Endpoints)
310		if bestEndpoint == nil {
311			fmt.Printf("Warning: No suitable endpoint found for %s\n", model.ID)
312			continue
313		}
314
315		// Check if the best endpoint supports tools
316		if !slices.Contains(bestEndpoint.SupportedParams, "tools") {
317			continue
318		}
319
320		// Use the best endpoint's configuration
321		pricing := ModelPricing{}
322		costPrompt, err := strconv.ParseFloat(bestEndpoint.Pricing.Prompt, 64)
323		if err != nil {
324			costPrompt = 0.0
325		}
326		pricing.CostPer1MIn = costPrompt * 1_000_000
327		costCompletion, err := strconv.ParseFloat(bestEndpoint.Pricing.Completion, 64)
328		if err != nil {
329			costCompletion = 0.0
330		}
331		pricing.CostPer1MOut = costCompletion * 1_000_000
332
333		costPromptCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheWrite, 64)
334		if err != nil {
335			costPromptCached = 0.0
336		}
337		pricing.CostPer1MInCached = costPromptCached * 1_000_000
338		costCompletionCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheRead, 64)
339		if err != nil {
340			costCompletionCached = 0.0
341		}
342		pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
343
344		canReason := slices.Contains(bestEndpoint.SupportedParams, "reasoning")
345		supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
346
347		var reasoningLevels []string
348		var defaultReasoning string
349		if canReason {
350			reasoningLevels = []string{"low", "medium", "high"}
351			defaultReasoning = "medium"
352		}
353		m := catwalk.Model{
354			ID:                     model.ID,
355			Name:                   model.Name,
356			CostPer1MIn:            pricing.CostPer1MIn,
357			CostPer1MOut:           pricing.CostPer1MOut,
358			CostPer1MInCached:      pricing.CostPer1MInCached,
359			CostPer1MOutCached:     pricing.CostPer1MOutCached,
360			ContextWindow:          bestEndpoint.ContextLength,
361			CanReason:              canReason,
362			DefaultReasoningEffort: defaultReasoning,
363			ReasoningLevels:        reasoningLevels,
364			SupportsImages:         supportsImages,
365		}
366
367		// Set max tokens based on the best endpoint
368		if bestEndpoint.MaxCompletionTokens != nil {
369			m.DefaultMaxTokens = *bestEndpoint.MaxCompletionTokens / 2
370		} else {
371			m.DefaultMaxTokens = bestEndpoint.ContextLength / 10
372		}
373
374		openRouterProvider.Models = append(openRouterProvider.Models, m)
375		fmt.Printf("Added model %s with context window %d from provider %s\n",
376			model.ID, bestEndpoint.ContextLength, bestEndpoint.ProviderName)
377	}
378
379	slices.SortFunc(openRouterProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
380		return strings.Compare(a.Name, b.Name)
381	})
382
383	// save the json in internal/providers/config/openrouter.json
384	data, err := json.MarshalIndent(openRouterProvider, "", "  ")
385	if err != nil {
386		log.Fatal("Error marshaling OpenRouter provider:", err)
387	}
388	// write to file
389	if err := os.WriteFile("internal/providers/configs/openrouter.json", data, 0o600); err != nil {
390		log.Fatal("Error writing OpenRouter provider config:", err)
391	}
392}