main.go

  1// Package main provides a command-line tool to fetch models from OpenRouter
  2// and generate a configuration file for the provider.
  3package main
  4
  5import (
  6	"context"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"log"
 11	"math"
 12	"net/http"
 13	"os"
 14	"slices"
 15	"strconv"
 16	"strings"
 17	"time"
 18
 19	"charm.land/catwalk/pkg/catwalk"
 20)
 21
// Model represents a single model entry as returned by the OpenRouter
// models list API (/api/v1/models).
type Model struct {
	// ID is the OpenRouter model identifier (e.g. "anthropic/claude-sonnet-4");
	// it is used both for the endpoints lookup URL and as the generated model ID.
	ID            string `json:"id"`
	CanonicalSlug string `json:"canonical_slug"`
	HuggingFaceID string `json:"hugging_face_id"`
	Name          string `json:"name"`
	// Created is a Unix timestamp from the API; not used by this tool.
	Created     int64  `json:"created"`
	Description string `json:"description"`
	// ContextLength is the model's context window in tokens; main skips
	// models where this is below 20000.
	ContextLength int64        `json:"context_length"`
	Architecture  Architecture `json:"architecture"`
	Pricing       Pricing      `json:"pricing"`
	TopProvider   TopProvider  `json:"top_provider"`
	// SupportedParams lists request parameters the model accepts; "tools"
	// and "reasoning" are the values this tool inspects.
	SupportedParams []string `json:"supported_parameters"`
}
 36
// Architecture defines the model's architecture details.
type Architecture struct {
	// Modality is the combined modality string reported by the API;
	// this tool does not parse it and relies on the lists below instead.
	Modality string `json:"modality"`
	// InputModalities lists accepted input types; main requires "text"
	// and uses the presence of "image" to flag image support.
	InputModalities []string `json:"input_modalities"`
	// OutputModalities lists produced output types; main requires "text".
	OutputModalities []string `json:"output_modalities"`
	Tokenizer        string   `json:"tokenizer"`
	// InstructType is nil when the API reports no instruct template.
	InstructType *string `json:"instruct_type"`
}
 45
// Pricing contains the price strings for different operations as
// returned by the OpenRouter API. Token-based fields are treated as
// per-token prices by getPricing, which scales them by 1,000,000 to get
// per-1M-token costs; unparseable values are treated as zero.
type Pricing struct {
	Prompt            string `json:"prompt"`
	Completion        string `json:"completion"`
	Request           string `json:"request"`
	Image             string `json:"image"`
	WebSearch         string `json:"web_search"`
	InternalReasoning string `json:"internal_reasoning"`
	// InputCacheRead and InputCacheWrite feed the cached-cost fields of
	// ModelPricing (read -> out-cached, write -> in-cached).
	InputCacheRead  string `json:"input_cache_read"`
	InputCacheWrite string `json:"input_cache_write"`
}
 57
// TopProvider describes the top provider's capabilities as reported by
// the models API; main uses it as a fallback when the per-model
// endpoints lookup fails.
type TopProvider struct {
	// ContextLength is the provider's context window in tokens; main
	// only trusts values greater than zero.
	ContextLength int64 `json:"context_length"`
	// MaxCompletionTokens is nil when the provider reports no limit.
	MaxCompletionTokens *int64 `json:"max_completion_tokens"`
	IsModerated         bool   `json:"is_moderated"`
}
 64
// Endpoint represents a single provider endpoint serving a model, as
// returned by the OpenRouter endpoints API.
type Endpoint struct {
	Name string `json:"name"`
	// ContextLength is this endpoint's context window in tokens.
	ContextLength int64   `json:"context_length"`
	Pricing       Pricing `json:"pricing"`
	ProviderName  string  `json:"provider_name"`
	Tag           string  `json:"tag"`
	// Quantization is nil when the provider does not report one.
	Quantization        *string `json:"quantization"`
	MaxCompletionTokens *int64  `json:"max_completion_tokens"`
	MaxPromptTokens     *int64  `json:"max_prompt_tokens"`
	// SupportedParams lists request parameters this endpoint accepts;
	// "tools" and "reasoning" are the values this tool inspects.
	SupportedParams []string `json:"supported_parameters"`
	// Status is treated as unhealthy when negative (see
	// selectBestEndpoint). NOTE(review): exact semantics come from the
	// OpenRouter API — confirm.
	Status int `json:"status"`
	// UptimeLast30m is presumably a percentage (selectBestEndpoint
	// compares it against 90.0) — confirm against the API docs.
	UptimeLast30m float64 `json:"uptime_last_30m"`
}
 79
// EndpointsResponse is the response structure for the endpoints API
// (/api/v1/models/{id}/endpoints); the payload is wrapped in a single
// "data" object.
type EndpointsResponse struct {
	Data struct {
		ID          string     `json:"id"`
		Name        string     `json:"name"`
		Created     int64      `json:"created"`
		Description string     `json:"description"`
		Endpoints   []Endpoint `json:"endpoints"`
	} `json:"data"`
}
 90
// ModelsResponse is the response structure for the models API
// (/api/v1/models): a flat list wrapped in a "data" array.
type ModelsResponse struct {
	Data []Model `json:"data"`
}
 95
// ModelPricing is the pricing structure for a model, detailing costs per
// million tokens for input and output, both cached and uncached. Values
// are produced by getPricing (and the equivalent endpoint-pricing code
// in main), rounded to five decimal places.
type ModelPricing struct {
	CostPer1MIn        float64 `json:"cost_per_1m_in"`
	CostPer1MOut       float64 `json:"cost_per_1m_out"`
	CostPer1MInCached  float64 `json:"cost_per_1m_in_cached"`
	CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
}
104
// roundCost rounds v to five decimal places, stripping the floating
// point noise introduced when scaling per-token prices to per-million.
func roundCost(v float64) float64 {
	const decimals = 100000.0
	scaled := math.Round(v * decimals)
	return scaled / decimals
}
108
109func getPricing(model Model) ModelPricing {
110	pricing := ModelPricing{}
111	costPrompt, err := strconv.ParseFloat(model.Pricing.Prompt, 64)
112	if err != nil {
113		costPrompt = 0.0
114	}
115	pricing.CostPer1MIn = roundCost(costPrompt * 1_000_000)
116	costCompletion, err := strconv.ParseFloat(model.Pricing.Completion, 64)
117	if err != nil {
118		costCompletion = 0.0
119	}
120	pricing.CostPer1MOut = roundCost(costCompletion * 1_000_000)
121
122	costPromptCached, err := strconv.ParseFloat(model.Pricing.InputCacheWrite, 64)
123	if err != nil {
124		costPromptCached = 0.0
125	}
126	pricing.CostPer1MInCached = roundCost(costPromptCached * 1_000_000)
127	costCompletionCached, err := strconv.ParseFloat(model.Pricing.InputCacheRead, 64)
128	if err != nil {
129		costCompletionCached = 0.0
130	}
131	pricing.CostPer1MOutCached = roundCost(costCompletionCached * 1_000_000)
132	return pricing
133}
134
135func fetchOpenRouterModels() (*ModelsResponse, error) {
136	client := &http.Client{Timeout: 30 * time.Second}
137	req, _ := http.NewRequestWithContext(
138		context.Background(),
139		"GET",
140		"https://openrouter.ai/api/v1/models",
141		nil,
142	)
143	req.Header.Set("User-Agent", "Crush-Client/1.0")
144
145	resp, err := client.Do(req)
146	if err != nil {
147		return nil, err //nolint:wrapcheck
148	}
149	defer resp.Body.Close() //nolint:errcheck
150
151	body, err := io.ReadAll(resp.Body)
152	if err != nil {
153		return nil, fmt.Errorf("unable to read models response body: %w", err)
154	}
155
156	if resp.StatusCode != http.StatusOK {
157		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
158	}
159
160	// for debugging
161	_ = os.MkdirAll("tmp", 0o700)
162	_ = os.WriteFile("tmp/openrouter-response.json", body, 0o600)
163
164	var mr ModelsResponse
165	if err := json.Unmarshal(body, &mr); err != nil {
166		return nil, err //nolint:wrapcheck
167	}
168	return &mr, nil
169}
170
171func fetchModelEndpoints(modelID string) (*EndpointsResponse, error) {
172	client := &http.Client{Timeout: 30 * time.Second}
173	url := fmt.Sprintf("https://openrouter.ai/api/v1/models/%s/endpoints", modelID)
174	req, _ := http.NewRequestWithContext(
175		context.Background(),
176		"GET",
177		url,
178		nil,
179	)
180	req.Header.Set("User-Agent", "Crush-Client/1.0")
181	resp, err := client.Do(req)
182	if err != nil {
183		return nil, err //nolint:wrapcheck
184	}
185	defer resp.Body.Close() //nolint:errcheck
186	if resp.StatusCode != 200 {
187		body, _ := io.ReadAll(resp.Body)
188		return nil, fmt.Errorf("status %d: %s", resp.StatusCode, body)
189	}
190	var er EndpointsResponse
191	if err := json.NewDecoder(resp.Body).Decode(&er); err != nil {
192		return nil, err //nolint:wrapcheck
193	}
194	return &er, nil
195}
196
197func selectBestEndpoint(endpoints []Endpoint) *Endpoint {
198	if len(endpoints) == 0 {
199		return nil
200	}
201
202	var best *Endpoint
203	for i := range endpoints {
204		endpoint := &endpoints[i]
205		// Skip endpoints with poor status or uptime
206		if endpoint.Status < 0 || endpoint.UptimeLast30m < 90.0 {
207			continue
208		}
209
210		if best == nil {
211			best = endpoint
212			continue
213		}
214
215		if isBetterEndpoint(endpoint, best) {
216			best = endpoint
217		}
218	}
219
220	// If no good endpoint found, return the first one as fallback
221	if best == nil {
222		best = &endpoints[0]
223	}
224
225	return best
226}
227
228func isBetterEndpoint(candidate, current *Endpoint) bool {
229	candidateHasTools := slices.Contains(candidate.SupportedParams, "tools")
230	currentHasTools := slices.Contains(current.SupportedParams, "tools")
231
232	// Prefer endpoints with tool support over those without
233	if candidateHasTools && !currentHasTools {
234		return true
235	}
236	if !candidateHasTools && currentHasTools {
237		return false
238	}
239
240	// Both have same tool support status, compare other factors
241	if candidate.ContextLength > current.ContextLength {
242		return true
243	}
244	if candidate.ContextLength == current.ContextLength {
245		return candidate.UptimeLast30m > current.UptimeLast30m
246	}
247
248	return false
249}
250
251// This is used to generate the openrouter.json config file.
252func main() {
253	modelsResp, err := fetchOpenRouterModels()
254	if err != nil {
255		log.Fatal("Error fetching OpenRouter models:", err)
256	}
257
258	openRouterProvider := catwalk.Provider{
259		Name:                "OpenRouter",
260		ID:                  "openrouter",
261		APIKey:              "$OPENROUTER_API_KEY",
262		APIEndpoint:         "https://openrouter.ai/api/v1",
263		Type:                catwalk.TypeOpenRouter,
264		DefaultLargeModelID: "anthropic/claude-sonnet-4",
265		DefaultSmallModelID: "anthropic/claude-3.5-haiku",
266		Models:              []catwalk.Model{},
267		DefaultHeaders: map[string]string{
268			"HTTP-Referer": "https://charm.land",
269			"X-Title":      "Crush",
270		},
271	}
272
273	for _, model := range modelsResp.Data {
274		if model.ContextLength < 20000 {
275			continue
276		}
277		// skip non‐text models or those without tools
278		if !slices.Contains(model.SupportedParams, "tools") ||
279			!slices.Contains(model.Architecture.InputModalities, "text") ||
280			!slices.Contains(model.Architecture.OutputModalities, "text") {
281			continue
282		}
283
284		// Fetch endpoints for this model to get the best configuration
285		endpointsResp, err := fetchModelEndpoints(model.ID)
286		if err != nil {
287			fmt.Printf("Warning: Failed to fetch endpoints for %s: %v\n", model.ID, err)
288			// Fall back to using the original model data
289			pricing := getPricing(model)
290			canReason := slices.Contains(model.SupportedParams, "reasoning")
291			supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
292
293			var reasoningLevels []string
294			var defaultReasoning string
295			if canReason {
296				reasoningLevels = []string{"low", "medium", "high"}
297				defaultReasoning = "medium"
298			}
299			m := catwalk.Model{
300				ID:                     model.ID,
301				Name:                   model.Name,
302				CostPer1MIn:            pricing.CostPer1MIn,
303				CostPer1MOut:           pricing.CostPer1MOut,
304				CostPer1MInCached:      pricing.CostPer1MInCached,
305				CostPer1MOutCached:     pricing.CostPer1MOutCached,
306				ContextWindow:          model.ContextLength,
307				CanReason:              canReason,
308				DefaultReasoningEffort: defaultReasoning,
309				ReasoningLevels:        reasoningLevels,
310				SupportsImages:         supportsImages,
311			}
312			if model.TopProvider.MaxCompletionTokens != nil {
313				m.DefaultMaxTokens = *model.TopProvider.MaxCompletionTokens / 2
314			} else {
315				m.DefaultMaxTokens = model.ContextLength / 10
316			}
317			if model.TopProvider.ContextLength > 0 {
318				m.ContextWindow = model.TopProvider.ContextLength
319			}
320			openRouterProvider.Models = append(openRouterProvider.Models, m)
321			continue
322		}
323
324		// Select the best endpoint
325		bestEndpoint := selectBestEndpoint(endpointsResp.Data.Endpoints)
326		if bestEndpoint == nil {
327			fmt.Printf("Warning: No suitable endpoint found for %s\n", model.ID)
328			continue
329		}
330
331		// Check if the best endpoint supports tools
332		if !slices.Contains(bestEndpoint.SupportedParams, "tools") {
333			continue
334		}
335
336		// Use the best endpoint's configuration
337		pricing := ModelPricing{}
338		costPrompt, err := strconv.ParseFloat(bestEndpoint.Pricing.Prompt, 64)
339		if err != nil {
340			costPrompt = 0.0
341		}
342		pricing.CostPer1MIn = roundCost(costPrompt * 1_000_000)
343		costCompletion, err := strconv.ParseFloat(bestEndpoint.Pricing.Completion, 64)
344		if err != nil {
345			costCompletion = 0.0
346		}
347		pricing.CostPer1MOut = roundCost(costCompletion * 1_000_000)
348
349		costPromptCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheWrite, 64)
350		if err != nil {
351			costPromptCached = 0.0
352		}
353		pricing.CostPer1MInCached = roundCost(costPromptCached * 1_000_000)
354		costCompletionCached, err := strconv.ParseFloat(bestEndpoint.Pricing.InputCacheRead, 64)
355		if err != nil {
356			costCompletionCached = 0.0
357		}
358		pricing.CostPer1MOutCached = roundCost(costCompletionCached * 1_000_000)
359
360		canReason := slices.Contains(bestEndpoint.SupportedParams, "reasoning")
361		supportsImages := slices.Contains(model.Architecture.InputModalities, "image")
362
363		var reasoningLevels []string
364		var defaultReasoning string
365		if canReason {
366			reasoningLevels = []string{"low", "medium", "high"}
367			defaultReasoning = "medium"
368		}
369		m := catwalk.Model{
370			ID:                     model.ID,
371			Name:                   model.Name,
372			CostPer1MIn:            pricing.CostPer1MIn,
373			CostPer1MOut:           pricing.CostPer1MOut,
374			CostPer1MInCached:      pricing.CostPer1MInCached,
375			CostPer1MOutCached:     pricing.CostPer1MOutCached,
376			ContextWindow:          bestEndpoint.ContextLength,
377			CanReason:              canReason,
378			DefaultReasoningEffort: defaultReasoning,
379			ReasoningLevels:        reasoningLevels,
380			SupportsImages:         supportsImages,
381		}
382
383		// Set max tokens based on the best endpoint
384		if bestEndpoint.MaxCompletionTokens != nil {
385			m.DefaultMaxTokens = *bestEndpoint.MaxCompletionTokens / 2
386		} else {
387			m.DefaultMaxTokens = bestEndpoint.ContextLength / 10
388		}
389
390		openRouterProvider.Models = append(openRouterProvider.Models, m)
391		fmt.Printf("Added model %s with context window %d from provider %s\n",
392			model.ID, bestEndpoint.ContextLength, bestEndpoint.ProviderName)
393	}
394
395	slices.SortFunc(openRouterProvider.Models, func(a catwalk.Model, b catwalk.Model) int {
396		return strings.Compare(a.Name, b.Name)
397	})
398
399	// save the json in internal/providers/config/openrouter.json
400	data, err := json.MarshalIndent(openRouterProvider, "", "  ")
401	if err != nil {
402		log.Fatal("Error marshaling OpenRouter provider:", err)
403	}
404	// write to file
405	if err := os.WriteFile("internal/providers/configs/openrouter.json", data, 0o600); err != nil {
406		log.Fatal("Error writing OpenRouter provider config:", err)
407	}
408}