1// Package main provides a command-line tool to fetch models from OpenRouter
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"net/http"
 12	"os"
 13	"slices"
 14	"strconv"
 15	"strings"
 16	"time"
 17
 18	"github.com/charmbracelet/catwalk/pkg/catwalk"
 19)
 20
// Model represents the complete model configuration as returned by the
// OpenRouter /api/v1/models endpoint.
type Model struct {
	ID            string `json:"id"`
	CanonicalSlug string `json:"canonical_slug"`
	HuggingFaceID string `json:"hugging_face_id"`
	Name          string `json:"name"`
	Created       int64  `json:"created"`
	Description   string `json:"description"`
	// ContextLength is the model-level context window in tokens; a positive
	// TopProvider.ContextLength takes precedence over it in main.
	ContextLength int64        `json:"context_length"`
	Architecture  Architecture `json:"architecture"`
	Pricing       Pricing      `json:"pricing"`
	TopProvider   TopProvider  `json:"top_provider"`
	// SupportedParams lists request parameters the model accepts; this tool
	// looks for "tools" and "reasoning" in it.
	SupportedParams []string `json:"supported_parameters"`
}
 35
// Architecture defines the model's architecture details.
type Architecture struct {
	Modality string `json:"modality"`
	// InputModalities and OutputModalities are checked for "text" (and,
	// for input, "image") when filtering models in main.
	InputModalities  []string `json:"input_modalities"`
	OutputModalities []string `json:"output_modalities"`
	Tokenizer        string   `json:"tokenizer"`
	InstructType     *string  `json:"instruct_type"` // pointer: may be null in the API response
}
 44
// Pricing contains the pricing information for different operations.
// All values arrive as decimal strings; they are parsed with
// strconv.ParseFloat and scaled to cost per one million tokens (see
// getPricing), with unparsable values treated as zero.
type Pricing struct {
	Prompt            string `json:"prompt"`
	Completion        string `json:"completion"`
	Request           string `json:"request"`
	Image             string `json:"image"`
	WebSearch         string `json:"web_search"`
	InternalReasoning string `json:"internal_reasoning"`
	InputCacheRead    string `json:"input_cache_read"`
	InputCacheWrite   string `json:"input_cache_write"`
}
 56
// TopProvider describes the top provider's capabilities.
type TopProvider struct {
	// ContextLength, when positive, overrides the model-level context window.
	ContextLength int64 `json:"context_length"`
	// MaxCompletionTokens is nil when the provider publishes no completion
	// cap; main then falls back to ContextLength / 10 for default max tokens.
	MaxCompletionTokens *int64 `json:"max_completion_tokens"`
	IsModerated         bool   `json:"is_moderated"`
}
 63
// Endpoint represents a single endpoint configuration for a model.
type Endpoint struct {
	Name                string   `json:"name"`
	ContextLength       int64    `json:"context_length"`
	Pricing             Pricing  `json:"pricing"`
	ProviderName        string   `json:"provider_name"`
	Tag                 string   `json:"tag"`
	Quantization        *string  `json:"quantization"`
	MaxCompletionTokens *int64   `json:"max_completion_tokens"`
	MaxPromptTokens     *int64   `json:"max_prompt_tokens"`
	SupportedParams     []string `json:"supported_parameters"`
	// Status below zero marks an endpoint as unhealthy; selectBestEndpoint
	// skips such endpoints.
	Status int `json:"status"`
	// UptimeLast30m is a percentage; endpoints below 90.0 are skipped by
	// selectBestEndpoint.
	UptimeLast30m float64 `json:"uptime_last_30m"`
}
 78
// EndpointsResponse is the response structure for the per-model endpoints
// API (/api/v1/models/{id}/endpoints).
type EndpointsResponse struct {
	Data struct {
		ID          string     `json:"id"`
		Name        string     `json:"name"`
		Created     int64      `json:"created"`
		Description string     `json:"description"`
		Endpoints   []Endpoint `json:"endpoints"`
	} `json:"data"`
}
 89
// ModelsResponse is the response structure for the models API
// (/api/v1/models).
type ModelsResponse struct {
	Data []Model `json:"data"`
}
 94
// ModelPricing is the pricing structure for a model, detailing costs per
// million tokens for input and output, both cached and uncached.
type ModelPricing struct {
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
}
103
104func getPricing(model Model) ModelPricing {
105	pricing := ModelPricing{}
106	costPrompt, err := strconv.ParseFloat(model.Pricing.Prompt, 64)
107	if err != nil {
108		costPrompt = 0.0
109	}
110	pricing.CostPer1MIn = costPrompt * 1_000_000
111	costCompletion, err := strconv.ParseFloat(model.Pricing.Completion, 64)
112	if err != nil {
113		costCompletion = 0.0
114	}
115	pricing.CostPer1MOut = costCompletion * 1_000_000
116
117	costPromptCached, err := strconv.ParseFloat(model.Pricing.InputCacheWrite, 64)
118	if err != nil {
119		costPromptCached = 0.0
120	}
121	pricing.CostPer1MInCached = costPromptCached * 1_000_000
122	costCompletionCached, err := strconv.ParseFloat(model.Pricing.InputCacheRead, 64)
123	if err != nil {
124		costCompletionCached = 0.0
125	}
126	pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
127	return pricing
128}
129
130func fetchOpenRouterModels() (*ModelsResponse, error) {
131	client := &http.Client{Timeout: 30 * time.Second}
132	req, _ := http.NewRequestWithContext(
133		context.Background(),
134		"GET",
135		"https://openrouter.ai/api/v1/models",
136		nil,
137	)
138	req.Header.Set("User-Agent", "Crush-Client/1.0")
139	resp, err := client.Do(req)
140	if err != nil {
141		return nil, err //nolint:wrapcheck
142	}
143	defer resp.Body.Close() //nolint:errcheck
144	if resp.StatusCode != 200 {
145		body, _ := io.ReadAll(resp.Body)
146		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
147	}
148	var mr ModelsResponse
149	if err := json.NewDecoder(resp.Body).Decode(&mr); err != nil {
150		return nil, err //nolint:wrapcheck
151	}
152	return &mr, nil
153}
154
155func fetchModelEndpoints(modelID string) (*EndpointsResponse, error) {
156	client := &http.Client{Timeout: 30 * time.Second}
157	url := fmt.Sprintf("https://openrouter.ai/api/v1/models/%s/endpoints", modelID)
158	req, _ := http.NewRequestWithContext(
159		context.Background(),
160		"GET",
161		url,
162		nil,
163	)
164	req.Header.Set("User-Agent", "Crush-Client/1.0")
165	resp, err := client.Do(req)
166	if err != nil {
167		return nil, err //nolint:wrapcheck
168	}
169	defer resp.Body.Close() //nolint:errcheck
170	if resp.StatusCode != 200 {
171		body, _ := io.ReadAll(resp.Body)
172		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
173	}
174	var er EndpointsResponse
175	if err := json.NewDecoder(resp.Body).Decode(&er); err != nil {
176		return nil, err //nolint:wrapcheck
177	}
178	return &er, nil
179}
180
181func selectBestEndpoint(endpoints []Endpoint) *Endpoint {
182	if len(endpoints) == 0 {
183		return nil
184	}
185
186	var best *Endpoint
187	for i := range endpoints {
188		endpoint := &endpoints[i]
189		// Skip endpoints with poor status or uptime
190		if endpoint.Status < 0 || endpoint.UptimeLast30m < 90.0 {
191			continue
192		}
193
194		if best == nil {
195			best = endpoint
196			continue
197		}
198
199		if isBetterEndpoint(endpoint, best) {
200			best = endpoint
201		}
202	}
203
204	// If no good endpoint found, return the first one as fallback
205	if best == nil {
206		best = &endpoints[0]
207	}
208
209	return best
210}
211
212func isBetterEndpoint(candidate, current *Endpoint) bool {
213	candidateHasTools := slices.Contains(candidate.SupportedParams, "tools")
214	currentHasTools := slices.Contains(current.SupportedParams, "tools")
215
216	// Prefer endpoints with tool support over those without
217	if candidateHasTools && !currentHasTools {
218		return true
219	}
220	if !candidateHasTools && currentHasTools {
221		return false
222	}
223
224	// Both have same tool support status, compare other factors
225	if candidate.ContextLength > current.ContextLength {
226		return true
227	}
228	if candidate.ContextLength == current.ContextLength {
229		return candidate.UptimeLast30m > current.UptimeLast30m
230	}
231
232	return false
233}
234
235// This is used to generate the openrouter.json config file.
236func main() {
237	modelsResp, err := fetchOpenRouterModels()
238	if err != nil {
239		log.Fatal("Error fetching OpenRouter models:", err)
240	}
241
242	openRouterProvider := catwalk.Provider{
243		Name:                "OpenRouter",
244		ID:                  "openrouter",
245		APIKey:              "$OPENROUTER_API_KEY",
246		APIEndpoint:         "https://openrouter.ai/api/v1",
247		Type:                catwalk.TypeOpenAI,
248		DefaultLargeModelID: "anthropic/claude-sonnet-4",
249		DefaultSmallModelID: "anthropic/claude-3.5-haiku",
250		Models:              []catwalk.Model{},
251		DefaultHeaders: map[string]string{
252			"HTTP-Referer": "https://charm.land",
253			"X-Title":      "Crush",
254		},
255	}
256
257	for _, model := range modelsResp.Data {
258		// skip nonβtext models or those without tools
259		if !slices.Contains(model.SupportedParams, "tools") ||
260			!slices.Contains(model.Architecture.InputModalities, "text") ||
261			!slices.Contains(model.Architecture.OutputModalities, "text") {
262			continue
263		}
264
265		// Fetch endpoints for this model to get the best configuration
266		endpointsResp, err := fetchModelEndpoints(model.ID)
267		if err != nil {
268			fmt.Printf("Warning: Failed to fetch endpoints for %s: %v\n", model.ID, err)
269			// Fall back to using the original model data
270			pricing := getPricing(model)
271			canReason := slices.Contains(model.SupportedParams, "reasoning")
272			supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
273
274			m := catwalk.Model{
275				ID:                 model.ID,
276				Name:               model.Name,
277				CostPer1MIn:        pricing.CostPer1MIn,
278				CostPer1MOut:       pricing.CostPer1MOut,
279				CostPer1MInCached:  pricing.CostPer1MInCached,
280				CostPer1MOutCached: pricing.CostPer1MOutCached,
281				ContextWindow:      model.ContextLength,
282				CanReason:          canReason,
283				HasReasoningEffort: canReason,
284				SupportsImages:     supportsImages,
285			}
286			if model.TopProvider.MaxCompletionTokens != nil {
287				m.DefaultMaxTokens = *model.TopProvider.MaxCompletionTokens / 2
288			} else {
289				m.DefaultMaxTokens = model.ContextLength / 10
290			}
291			if model.TopProvider.ContextLength > 0 {
292				m.ContextWindow = model.TopProvider.ContextLength
293			}
294			openRouterProvider.Models = append(openRouterProvider.Models, m)
295			continue
296		}
297
298		// Select the best endpoint
299		bestEndpoint := selectBestEndpoint(endpointsResp.Data.Endpoints)
300		if bestEndpoint == nil {
301			fmt.Printf("Warning: No suitable endpoint found for %s\n", model.ID)
302			continue
303		}
304
305		// Check if the best endpoint supports tools
306		if !slices.Contains(bestEndpoint.SupportedParams, "tools") {
307			continue
308		}
309
310		// Use the best endpoint's configuration
311		pricing := ModelPricing{}
312		costPrompt, err := strconv.ParseFloat(bestEndpoint.Pricing.Prompt, 64)
313		if err != nil {
314			costPrompt = 0.0
315		}
316		pricing.CostPer1MIn = costPrompt * 1_000_000
317		costCompletion, err := strconv.ParseFloat(bestEndpoint.Pricing.Completion, 64)
318		if err != nil {
319			costCompletion = 0.0
320		}
321		pricing.CostPer1MOut = costCompletion * 1_000_000
322
323		costPromptCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheWrite, 64)
324		if err != nil {
325			costPromptCached = 0.0
326		}
327		pricing.CostPer1MInCached = costPromptCached * 1_000_000
328		costCompletionCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheRead, 64)
329		if err != nil {
330			costCompletionCached = 0.0
331		}
332		pricing.CostPer1MOutCached = costCompletionCached * 1_000_000
333
334		canReason := slices.Contains(bestEndpoint.SupportedParams, "reasoning")
335		supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
336
337		m := catwalk.Model{
338			ID:                 model.ID,
339			Name:               model.Name,
340			CostPer1MIn:        pricing.CostPer1MIn,
341			CostPer1MOut:       pricing.CostPer1MOut,
342			CostPer1MInCached:  pricing.CostPer1MInCached,
343			CostPer1MOutCached: pricing.CostPer1MOutCached,
344			ContextWindow:      bestEndpoint.ContextLength,
345			CanReason:          canReason,
346			HasReasoningEffort: canReason,
347			SupportsImages:     supportsImages,
348		}
349
350		// Set max tokens based on the best endpoint
351		if bestEndpoint.MaxCompletionTokens != nil {
352			m.DefaultMaxTokens = *bestEndpoint.MaxCompletionTokens / 2
353		} else {
354			m.DefaultMaxTokens = bestEndpoint.ContextLength / 10
355		}
356
357		openRouterProvider.Models = append(openRouterProvider.Models, m)
358		fmt.Printf("Added model %s with context window %d from provider %s\n",
359			model.ID, bestEndpoint.ContextLength, bestEndpoint.ProviderName)
360	}
361
362	slices.SortFunc(openRouterProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
363		return strings.Compare(a.Name, b.Name)
364	})
365
366	// save the json in internal/providers/config/openrouter.json
367	data, err := json.MarshalIndent(openRouterProvider, "", "  ")
368	if err != nil {
369		log.Fatal("Error marshaling OpenRouter provider:", err)
370	}
371	// write to file
372	if err := os.WriteFile("internal/providers/configs/openrouter.json", data, 0o600); err != nil {
373		log.Fatal("Error writing OpenRouter provider config:", err)
374	}
375}